//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

static cl::opt<bool> EnableVectorFCopySignExtendRound(
    "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
    cl::desc(
        "Enable merging extends and rounds into FCOPYSIGN on vector types"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level = BeforeLegalizeTypes;
    CodeGenOpt::Level OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;
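    // For example, removing a node nulls out its worklist slot instead of
    // erasing it (see removeFromWorklist below):
    //   Worklist:    [A, B, C]     WorklistMap: {A:0, B:1, C:2}
    //   remove(B) -> [A, null, C]  WorklistMap: {A:0, C:2}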
    /// This records all nodes attempted to be added to the worklist since we
    /// last considered a new worklist entry. Because we do not add duplicate
    /// nodes to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in dependence check. If we have seen
    /// the bail out for the same pair many times over a limit, we won't
    /// consider the StoreNode with the same RootNode as store merging
    /// candidate again.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the worklist because they might now be simplifiable as well.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }
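      // For example, a scalable type such as MVT::nxv4i32 occupies
      // vscale x 128 bits, so getKnownMinValue() reports the guaranteed
      // minimum of 128 here.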

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      ConsiderForPruning(N);

      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnes(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
      KnownBits Known;
      if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
        return false;

      // Revisit the node.
      AddToWorklist(Op.getNode());

      CommitTargetLoweringOpt(TLO);
      return true;
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnes(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

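
    // Illustrative usage (a sketch, not code from this pass): a fold that
    // only needs the low byte of an operand could ask for just those bits:
    //   APInt Demanded = APInt::getLowBitsSet(BitWidth, 8);
    //   if (SimplifyDemandedBits(N->getOperand(0), Demanded))
    //     return SDValue(N, 0); // The operand was simplified in place.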
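
    // Illustrative usage (a sketch): to ask whether only element 0 of a
    // fixed-width vector matters, demand a single lane:
    //   APInt DemandedElts = APInt::getOneBitSet(NumElts, 0);
    //   SimplifyDemandedVectorElts(Op, DemandedElts);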
    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                SDValue RHS, SDValue True, SDValue False,
                                ISD::CondCode CC);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitADDCARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitSUBCARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitAVG(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
    SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitSHLSAT(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    SDValue visitFMA(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMinMax(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);

    bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitFP_TO_BF16(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);
    SDValue visitVPOp(SDNode *N);

    SDValue visitFADDForFMACombine(SDNode *N);
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL,
                                                    SDNode *N,
                                                    SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
    SDValue foldABSToABD(SDNode *N);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                         unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildSREMPow2(SDNode *N);
    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue reduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool mayAlias(SDNode *Op0, SDNode *Op1) const;

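    // Illustratively (a sketch of the convention, not code from any one
    // visit routine), the three outcomes look like:
    //   return SDValue();             // no change was made
    //   return SDValue(N, 0);         // N was replaced and handled
    //   return DAG.getNode(Opc, ...); // N should be replaced by this value
    //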
    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node).
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }
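
    // For example, a run of stores whose values are all ISD::Constant nodes
    // classifies as StoreSource::Constant, making it a candidate for the
    // constant-merging path in mergeConsecutiveStores.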

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                     SDValue ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to MemVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary.  \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. RootNode is
    /// a chain predecessor to all store candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if any are not met an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }
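
    // For example, hasOperation(ISD::USUBSAT, VT) guards folds that create
    // saturating-subtract nodes: a Custom lowering suffices before operation
    // legalization, while only a Legal operation qualifies afterwards.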

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}
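
// Illustrative sketch (not code from this file): a target's override of
// TargetLowering::PerformDAGCombine typically reaches back into the
// DAGCombiner through this DAGCombinerInfo interface. "FooTargetLowering"
// and simplifyNode() are hypothetical names.
//
//   SDValue FooTargetLowering::PerformDAGCombine(SDNode *N,
//                                                DAGCombinerInfo &DCI) const {
//     if (SDValue V = simplifyNode(N, DCI.DAG))
//       return DCI.CombineTo(N, V); // replaces all uses, updates worklist
//     return SDValue();
//   }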

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper
// function zero-extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zext(Bits);
  RHS = RHS.zext(Bits);
}
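
// For example, an 8-bit LHS and a 16-bit RHS with Offset == 1 are both
// extended to 17 bits, leaving one bit of headroom for carries.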

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
    return true;
  return false;
}

static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}
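
// For example, a v4i32 splat of 0x000000FF is exactly the mask for
// ScalarTy == MVT::i8, so this returns true for that pair.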

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if this is an indexed load with an opaque target constant index.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then
      // reassociating the constants breaks the address pattern.
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    return true;
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode);
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1);
      return DAG.getNode(Opc, DL, VT, OpNode, N01);
    }
  }

  // Check for repeated operand logic simplifications.
  if (Opc == ISD::AND || Opc == ISD::OR) {
    // (N00 & N01) & N00 --> N00 & N01
    // (N00 & N01) & N01 --> N00 & N01
    // (N00 | N01) | N00 --> N00 | N01
    // (N00 | N01) | N01 --> N00 | N01
    if (N1 == N00 || N1 == N01)
      return N0;
  }
  if (Opc == ISD::XOR) {
    // (N00 ^ N01) ^ N00 --> N01
    if (N1 == N00)
      return N01;
    // (N00 ^ N01) ^ N01 --> N00
    if (N1 == N01)
      return N00;
  }

  if (TLI.isReassocProfitable(DAG, N0, N1)) {
    if (N1 != N01) {
      // Reassociate if (op N00, N1) already exists
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // If (Op (Op N00, N1), N01) already exists, we must stop
        // reassociating here to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
    }

    if (N1 != N00) {
      // Reassociate if (op N01, N1) already exists
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // If (Op (Op N01, N1), N00) already exists, we must stop
        // reassociating here to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
    }
  }

  return SDValue();
}

// Try to reassociate commutative binops.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");

  // Floating-point reassociation is not allowed without loose FP math.
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
    return Combined;
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}

void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');

  // Replace all uses.
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.
  recursivelyDeleteUnusedNodes(TLO.Old.getNode());
}
1233 
1234 /// Check the specified integer node value to see if it can be simplified or if
1235 /// things it uses can be simplified by bit propagation. If so, return true.
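/// For example (illustrative), if only the low 8 bits of the result are
/// demanded, an (and X, 0xFF) feeding it can often be replaced by X itself.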
1236 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1237                                        const APInt &DemandedElts,
1238                                        bool AssumeSingleUse) {
1239   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1240   KnownBits Known;
1241   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1242                                 AssumeSingleUse))
1243     return false;
1244 
1245   // Revisit the node.
1246   AddToWorklist(Op.getNode());
1247 
1248   CommitTargetLoweringOpt(TLO);
1249   return true;
1250 }
1251 
1252 /// Check the specified vector node value to see if it can be simplified or
1253 /// if things it uses can be simplified as it only uses some of the elements.
1254 /// If so, return true.
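/// For example (illustrative), if only element 0 of a vector add is demanded,
/// the remaining lanes of its operands may be simplified to undef.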
1255 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1256                                              const APInt &DemandedElts,
1257                                              bool AssumeSingleUse) {
1258   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1259   APInt KnownUndef, KnownZero;
1260   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1261                                       TLO, 0, AssumeSingleUse))
1262     return false;
1263 
1264   // Revisit the node.
1265   AddToWorklist(Op.getNode());
1266 
1267   CommitTargetLoweringOpt(TLO);
1268   return true;
1269 }
1270 
1271 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1272   SDLoc DL(Load);
1273   EVT VT = Load->getValueType(0);
1274   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1275 
1276   LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1277              Trunc.dump(&DAG); dbgs() << '\n');
1278 
1279   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1280   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1281 
1282   AddToWorklist(Trunc.getNode());
1283   recursivelyDeleteUnusedNodes(Load);
1284 }
1285 
1286 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1287   Replace = false;
1288   SDLoc DL(Op);
1289   if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1290     LoadSDNode *LD = cast<LoadSDNode>(Op);
1291     EVT MemVT = LD->getMemoryVT();
1292     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1293                                                       : LD->getExtensionType();
1294     Replace = true;
1295     return DAG.getExtLoad(ExtType, DL, PVT,
1296                           LD->getChain(), LD->getBasePtr(),
1297                           MemVT, LD->getMemOperand());
1298   }
1299 
1300   unsigned Opc = Op.getOpcode();
1301   switch (Opc) {
1302   default: break;
1303   case ISD::AssertSext:
1304     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1305       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1306     break;
1307   case ISD::AssertZext:
1308     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1309       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1310     break;
1311   case ISD::Constant: {
1312     unsigned ExtOpc =
1313       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1314     return DAG.getNode(ExtOpc, DL, PVT, Op);
1315   }
1316   }
1317 
1318   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1319     return SDValue();
1320   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1321 }
1322 
1323 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1324   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1325     return SDValue();
1326   EVT OldVT = Op.getValueType();
1327   SDLoc DL(Op);
1328   bool Replace = false;
1329   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1330   if (!NewOp.getNode())
1331     return SDValue();
1332   AddToWorklist(NewOp.getNode());
1333 
1334   if (Replace)
1335     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1336   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1337                      DAG.getValueType(OldVT));
1338 }
1339 
1340 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1341   EVT OldVT = Op.getValueType();
1342   SDLoc DL(Op);
1343   bool Replace = false;
1344   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1345   if (!NewOp.getNode())
1346     return SDValue();
1347   AddToWorklist(NewOp.getNode());
1348 
1349   if (Replace)
1350     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1351   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1352 }
1353 
1354 /// Promote the specified integer binary operation if the target indicates it is
1355 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1356 /// i32 since i16 instructions are longer.
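/// The promoted form produced below is, schematically:
///   (i16 (add a, b)) -> (i16 (trunc (i32 (add (ext a), (ext b)))))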
1357 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1358   if (!LegalOperations)
1359     return SDValue();
1360 
1361   EVT VT = Op.getValueType();
1362   if (VT.isVector() || !VT.isInteger())
1363     return SDValue();
1364 
1365   // If operation type is 'undesirable', e.g. i16 on x86, consider
1366   // promoting it.
1367   unsigned Opc = Op.getOpcode();
1368   if (TLI.isTypeDesirableForOp(Opc, VT))
1369     return SDValue();
1370 
1371   EVT PVT = VT;
1372   // Consult target whether it is a good idea to promote this operation and
1373   // what's the right type to promote it to.
1374   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1375     assert(PVT != VT && "Don't know what type to promote to!");
1376 
1377     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1378 
1379     bool Replace0 = false;
1380     SDValue N0 = Op.getOperand(0);
1381     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1382 
1383     bool Replace1 = false;
1384     SDValue N1 = Op.getOperand(1);
1385     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1386     SDLoc DL(Op);
1387 
1388     SDValue RV =
1389         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1390 
1391     // We are always replacing N0/N1's use in N and only need additional
1392     // replacements if there are additional uses.
1393     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1394     //       (SDValue) here because the node may reference multiple values
1395     //       (for example, the chain value of a load node).
1396     Replace0 &= !N0->hasOneUse();
1397     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1398 
1399     // Combine Op here so it is preserved past replacements.
1400     CombineTo(Op.getNode(), RV);
1401 
1402     // If operands have a use ordering, make sure we deal with
1403     // predecessor first.
1404     if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1405       std::swap(N0, N1);
1406       std::swap(NN0, NN1);
1407     }
1408 
1409     if (Replace0) {
1410       AddToWorklist(NN0.getNode());
1411       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1412     }
1413     if (Replace1) {
1414       AddToWorklist(NN1.getNode());
1415       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1416     }
1417     return Op;
1418   }
1419   return SDValue();
1420 }
1421 
1422 /// Promote the specified integer shift operation if the target indicates it is
1423 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1424 /// i32 since i16 instructions are longer.
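/// The shifted operand must be extended to preserve the shift semantics:
/// sign-extend for SRA, zero-extend for SRL, any-extend otherwise, e.g.
///   (i16 (srl a, c)) -> (i16 (trunc (i32 (srl (zext a), c))))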
1425 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1426   if (!LegalOperations)
1427     return SDValue();
1428 
1429   EVT VT = Op.getValueType();
1430   if (VT.isVector() || !VT.isInteger())
1431     return SDValue();
1432 
1433   // If operation type is 'undesirable', e.g. i16 on x86, consider
1434   // promoting it.
1435   unsigned Opc = Op.getOpcode();
1436   if (TLI.isTypeDesirableForOp(Opc, VT))
1437     return SDValue();
1438 
1439   EVT PVT = VT;
1440   // Consult target whether it is a good idea to promote this operation and
1441   // what's the right type to promote it to.
1442   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1443     assert(PVT != VT && "Don't know what type to promote to!");
1444 
1445     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1446 
1447     bool Replace = false;
1448     SDValue N0 = Op.getOperand(0);
1449     if (Opc == ISD::SRA)
1450       N0 = SExtPromoteOperand(N0, PVT);
1451     else if (Opc == ISD::SRL)
1452       N0 = ZExtPromoteOperand(N0, PVT);
1453     else
1454       N0 = PromoteOperand(N0, PVT, Replace);
1455 
1456     if (!N0.getNode())
1457       return SDValue();
1458 
1459     SDLoc DL(Op);
1460     SDValue N1 = Op.getOperand(1);
1461     SDValue RV =
1462         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1463 
1464     if (Replace)
1465       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1466 
1467     // Deal with Op being deleted.
1468     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1469       return RV;
1470   }
1471   return SDValue();
1472 }
1473 
1474 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1475   if (!LegalOperations)
1476     return SDValue();
1477 
1478   EVT VT = Op.getValueType();
1479   if (VT.isVector() || !VT.isInteger())
1480     return SDValue();
1481 
1482   // If operation type is 'undesirable', e.g. i16 on x86, consider
1483   // promoting it.
1484   unsigned Opc = Op.getOpcode();
1485   if (TLI.isTypeDesirableForOp(Opc, VT))
1486     return SDValue();
1487 
1488   EVT PVT = VT;
1489   // Consult target whether it is a good idea to promote this operation and
1490   // what's the right type to promote it to.
1491   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1492     assert(PVT != VT && "Don't know what type to promote to!");
1493     // fold (aext (aext x)) -> (aext x)
1494     // fold (aext (zext x)) -> (zext x)
1495     // fold (aext (sext x)) -> (sext x)
1496     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1497     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1498   }
1499   return SDValue();
1500 }
1501 
1502 bool DAGCombiner::PromoteLoad(SDValue Op) {
1503   if (!LegalOperations)
1504     return false;
1505 
1506   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1507     return false;
1508 
1509   EVT VT = Op.getValueType();
1510   if (VT.isVector() || !VT.isInteger())
1511     return false;
1512 
1513   // If operation type is 'undesirable', e.g. i16 on x86, consider
1514   // promoting it.
1515   unsigned Opc = Op.getOpcode();
1516   if (TLI.isTypeDesirableForOp(Opc, VT))
1517     return false;
1518 
1519   EVT PVT = VT;
1520   // Consult target whether it is a good idea to promote this operation and
1521   // what's the right type to promote it to.
1522   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1523     assert(PVT != VT && "Don't know what type to promote to!");
1524 
1525     SDLoc DL(Op);
1526     SDNode *N = Op.getNode();
1527     LoadSDNode *LD = cast<LoadSDNode>(N);
1528     EVT MemVT = LD->getMemoryVT();
1529     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1530                                                       : LD->getExtensionType();
1531     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1532                                    LD->getChain(), LD->getBasePtr(),
1533                                    MemVT, LD->getMemOperand());
1534     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1535 
1536     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1537                Result.dump(&DAG); dbgs() << '\n');
1538 
1539     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1540     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1541 
1542     AddToWorklist(Result.getNode());
1543     recursivelyDeleteUnusedNodes(N);
1544     return true;
1545   }
1546 
1547   return false;
1548 }
1549 
1550 /// Recursively delete a node which has no uses and any operands for
1551 /// which it is the only use.
1552 ///
1553 /// Note that this both deletes the nodes and removes them from the worklist.
1554 /// It also adds any nodes that have had a user deleted to the worklist, as
1555 /// they may now have only one use and be subject to other combines.
1556 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1557   if (!N->use_empty())
1558     return false;
1559 
1560   SmallSetVector<SDNode *, 16> Nodes;
1561   Nodes.insert(N);
1562   do {
1563     N = Nodes.pop_back_val();
1564     if (!N)
1565       continue;
1566 
1567     if (N->use_empty()) {
1568       for (const SDValue &ChildN : N->op_values())
1569         Nodes.insert(ChildN.getNode());
1570 
1571       removeFromWorklist(N);
1572       DAG.DeleteNode(N);
1573     } else {
1574       AddToWorklist(N);
1575     }
1576   } while (!Nodes.empty());
1577   return true;
1578 }
1579 
1580 //===----------------------------------------------------------------------===//
1581 //  Main DAG Combiner implementation
1582 //===----------------------------------------------------------------------===//
1583 
1584 void DAGCombiner::Run(CombineLevel AtLevel) {
1585   // Set the instance variables, so that the various visit routines may use them.
1586   Level = AtLevel;
1587   LegalDAG = Level >= AfterLegalizeDAG;
1588   LegalOperations = Level >= AfterLegalizeVectorOps;
1589   LegalTypes = Level >= AfterLegalizeTypes;
1590 
1591   WorklistInserter AddNodes(*this);
1592 
1593   // Add all the dag nodes to the worklist.
1594   for (SDNode &Node : DAG.allnodes())
1595     AddToWorklist(&Node);
1596 
1597   // Create a dummy node (which is not added to allnodes) that adds a reference
1598   // to the root node, preventing it from being deleted, and tracking any
1599   // changes of the root.
1600   HandleSDNode Dummy(DAG.getRoot());
1601 
1602   // While we have a valid worklist entry node, try to combine it.
1603   while (SDNode *N = getNextWorklistEntry()) {
1604     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1605     // N is deleted from the DAG, since they too may now be dead or may have a
1606     // reduced number of uses, allowing other xforms.
1607     if (recursivelyDeleteUnusedNodes(N))
1608       continue;
1609 
1610     WorklistRemover DeadNodes(*this);
1611 
1612     // If this combine is running after legalizing the DAG, re-legalize any
1613     // nodes pulled off the worklist.
1614     if (LegalDAG) {
1615       SmallSetVector<SDNode *, 16> UpdatedNodes;
1616       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1617 
1618       for (SDNode *LN : UpdatedNodes)
1619         AddToWorklistWithUsers(LN);
1620 
1621       if (!NIsValid)
1622         continue;
1623     }
1624 
1625     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1626 
1627     // Add any operands of the new node which have not yet been combined to the
1628     // worklist as well. Because the worklist uniques things already, this
1629     // won't repeatedly process the same operand.
1630     CombinedNodes.insert(N);
1631     for (const SDValue &ChildN : N->op_values())
1632       if (!CombinedNodes.count(ChildN.getNode()))
1633         AddToWorklist(ChildN.getNode());
1634 
1635     SDValue RV = combine(N);
1636 
1637     if (!RV.getNode())
1638       continue;
1639 
1640     ++NodesCombined;
1641 
1642     // If we get back the same node we passed in, rather than a new node or
1643     // zero, we know that the node must have defined multiple values and
1644     // CombineTo was used.  Since CombineTo takes care of the worklist
1645     // mechanics for us, we have no work to do in this case.
1646     if (RV.getNode() == N)
1647       continue;
1648 
1649     assert(N->getOpcode() != ISD::DELETED_NODE &&
1650            RV.getOpcode() != ISD::DELETED_NODE &&
1651            "Node was deleted but visit returned new node!");
1652 
1653     LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1654 
1655     if (N->getNumValues() == RV->getNumValues())
1656       DAG.ReplaceAllUsesWith(N, RV.getNode());
1657     else {
1658       assert(N->getValueType(0) == RV.getValueType() &&
1659              N->getNumValues() == 1 && "Type mismatch");
1660       DAG.ReplaceAllUsesWith(N, &RV);
1661     }
1662 
1663     // Push the new node and any users onto the worklist.  Omit this if the
1664     // new node is the EntryToken (e.g. if a store managed to get optimized
1665     // out), because re-visiting the EntryToken and its users will not uncover
1666     // any additional opportunities, but there may be a large number of such
1667     // users, potentially causing compile time explosion.
1668     if (RV.getOpcode() != ISD::EntryToken) {
1669       AddToWorklist(RV.getNode());
1670       AddUsersToWorklist(RV.getNode());
1671     }
1672 
1673     // Finally, if the node is now dead, remove it from the graph.  The node
1674     // may not be dead if the replacement process recursively simplified to
1675     // something else needing this node. This will also take care of adding any
1676     // operands which have lost a user to the worklist.
1677     recursivelyDeleteUnusedNodes(N);
1678   }
1679 
1680   // If the root changed (e.g. it was a dead load), update the root.
1681   DAG.setRoot(Dummy.getValue());
1682   DAG.RemoveDeadNodes();
1683 }
1684 
1685 SDValue DAGCombiner::visit(SDNode *N) {
1686   switch (N->getOpcode()) {
1687   default: break;
1688   case ISD::TokenFactor:        return visitTokenFactor(N);
1689   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1690   case ISD::ADD:                return visitADD(N);
1691   case ISD::SUB:                return visitSUB(N);
1692   case ISD::SADDSAT:
1693   case ISD::UADDSAT:            return visitADDSAT(N);
1694   case ISD::SSUBSAT:
1695   case ISD::USUBSAT:            return visitSUBSAT(N);
1696   case ISD::ADDC:               return visitADDC(N);
1697   case ISD::SADDO:
1698   case ISD::UADDO:              return visitADDO(N);
1699   case ISD::SUBC:               return visitSUBC(N);
1700   case ISD::SSUBO:
1701   case ISD::USUBO:              return visitSUBO(N);
1702   case ISD::ADDE:               return visitADDE(N);
1703   case ISD::ADDCARRY:           return visitADDCARRY(N);
1704   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1705   case ISD::SUBE:               return visitSUBE(N);
1706   case ISD::SUBCARRY:           return visitSUBCARRY(N);
1707   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1708   case ISD::SMULFIX:
1709   case ISD::SMULFIXSAT:
1710   case ISD::UMULFIX:
1711   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1712   case ISD::MUL:                return visitMUL(N);
1713   case ISD::SDIV:               return visitSDIV(N);
1714   case ISD::UDIV:               return visitUDIV(N);
1715   case ISD::SREM:
1716   case ISD::UREM:               return visitREM(N);
1717   case ISD::MULHU:              return visitMULHU(N);
1718   case ISD::MULHS:              return visitMULHS(N);
1719   case ISD::AVGFLOORS:
1720   case ISD::AVGFLOORU:
1721   case ISD::AVGCEILS:
1722   case ISD::AVGCEILU:           return visitAVG(N);
1723   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1724   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1725   case ISD::SMULO:
1726   case ISD::UMULO:              return visitMULO(N);
1727   case ISD::SMIN:
1728   case ISD::SMAX:
1729   case ISD::UMIN:
1730   case ISD::UMAX:               return visitIMINMAX(N);
1731   case ISD::AND:                return visitAND(N);
1732   case ISD::OR:                 return visitOR(N);
1733   case ISD::XOR:                return visitXOR(N);
1734   case ISD::SHL:                return visitSHL(N);
1735   case ISD::SRA:                return visitSRA(N);
1736   case ISD::SRL:                return visitSRL(N);
1737   case ISD::ROTR:
1738   case ISD::ROTL:               return visitRotate(N);
1739   case ISD::FSHL:
1740   case ISD::FSHR:               return visitFunnelShift(N);
1741   case ISD::SSHLSAT:
1742   case ISD::USHLSAT:            return visitSHLSAT(N);
1743   case ISD::ABS:                return visitABS(N);
1744   case ISD::BSWAP:              return visitBSWAP(N);
1745   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1746   case ISD::CTLZ:               return visitCTLZ(N);
1747   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1748   case ISD::CTTZ:               return visitCTTZ(N);
1749   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1750   case ISD::CTPOP:              return visitCTPOP(N);
1751   case ISD::SELECT:             return visitSELECT(N);
1752   case ISD::VSELECT:            return visitVSELECT(N);
1753   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1754   case ISD::SETCC:              return visitSETCC(N);
1755   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1756   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1757   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1758   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1759   case ISD::AssertSext:
1760   case ISD::AssertZext:         return visitAssertExt(N);
1761   case ISD::AssertAlign:        return visitAssertAlign(N);
1762   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1763   case ISD::SIGN_EXTEND_VECTOR_INREG:
1764   case ISD::ZERO_EXTEND_VECTOR_INREG:
1765   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1766   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1767   case ISD::BITCAST:            return visitBITCAST(N);
1768   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1769   case ISD::FADD:               return visitFADD(N);
1770   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
1771   case ISD::FSUB:               return visitFSUB(N);
1772   case ISD::FMUL:               return visitFMUL(N);
1773   case ISD::FMA:                return visitFMA(N);
1774   case ISD::FDIV:               return visitFDIV(N);
1775   case ISD::FREM:               return visitFREM(N);
1776   case ISD::FSQRT:              return visitFSQRT(N);
1777   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1778   case ISD::FPOW:               return visitFPOW(N);
1779   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1780   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1781   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1782   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1783   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1784   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
1785   case ISD::FNEG:               return visitFNEG(N);
1786   case ISD::FABS:               return visitFABS(N);
1787   case ISD::FFLOOR:             return visitFFLOOR(N);
1788   case ISD::FMINNUM:
1789   case ISD::FMAXNUM:
1790   case ISD::FMINIMUM:
1791   case ISD::FMAXIMUM:           return visitFMinMax(N);
1792   case ISD::FCEIL:              return visitFCEIL(N);
1793   case ISD::FTRUNC:             return visitFTRUNC(N);
1794   case ISD::BRCOND:             return visitBRCOND(N);
1795   case ISD::BR_CC:              return visitBR_CC(N);
1796   case ISD::LOAD:               return visitLOAD(N);
1797   case ISD::STORE:              return visitSTORE(N);
1798   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
1799   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1800   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
1801   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
1802   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
1803   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
1804   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
1805   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
1806   case ISD::MGATHER:            return visitMGATHER(N);
1807   case ISD::MLOAD:              return visitMLOAD(N);
1808   case ISD::MSCATTER:           return visitMSCATTER(N);
1809   case ISD::MSTORE:             return visitMSTORE(N);
1810   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
1811   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
1812   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
1813   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
1814   case ISD::FREEZE:             return visitFREEZE(N);
1815   case ISD::VECREDUCE_FADD:
1816   case ISD::VECREDUCE_FMUL:
1817   case ISD::VECREDUCE_ADD:
1818   case ISD::VECREDUCE_MUL:
1819   case ISD::VECREDUCE_AND:
1820   case ISD::VECREDUCE_OR:
1821   case ISD::VECREDUCE_XOR:
1822   case ISD::VECREDUCE_SMAX:
1823   case ISD::VECREDUCE_SMIN:
1824   case ISD::VECREDUCE_UMAX:
1825   case ISD::VECREDUCE_UMIN:
1826   case ISD::VECREDUCE_FMAX:
1827   case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
1828 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
1829 #include "llvm/IR/VPIntrinsics.def"
1830     return visitVPOp(N);
1831   }
1832   return SDValue();
1833 }
1834 
1835 SDValue DAGCombiner::combine(SDNode *N) {
1836   SDValue RV;
1837   if (!DisableGenericCombines)
1838     RV = visit(N);
1839 
1840   // If nothing happened, try a target-specific DAG combine.
1841   if (!RV.getNode()) {
1842     assert(N->getOpcode() != ISD::DELETED_NODE &&
1843            "Node was deleted but visit returned NULL!");
1844 
1845     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1846         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
1847 
1848       // Expose the DAG combiner to the target combiner impls.
1849       TargetLowering::DAGCombinerInfo
1850         DagCombineInfo(DAG, Level, false, this);
1851 
1852       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
1853     }
1854   }
1855 
1856   // If nothing happened still, try promoting the operation.
1857   if (!RV.getNode()) {
1858     switch (N->getOpcode()) {
1859     default: break;
1860     case ISD::ADD:
1861     case ISD::SUB:
1862     case ISD::MUL:
1863     case ISD::AND:
1864     case ISD::OR:
1865     case ISD::XOR:
1866       RV = PromoteIntBinOp(SDValue(N, 0));
1867       break;
1868     case ISD::SHL:
1869     case ISD::SRA:
1870     case ISD::SRL:
1871       RV = PromoteIntShiftOp(SDValue(N, 0));
1872       break;
1873     case ISD::SIGN_EXTEND:
1874     case ISD::ZERO_EXTEND:
1875     case ISD::ANY_EXTEND:
1876       RV = PromoteExtend(SDValue(N, 0));
1877       break;
1878     case ISD::LOAD:
1879       if (PromoteLoad(SDValue(N, 0)))
1880         RV = SDValue(N, 0);
1881       break;
1882     }
1883   }
1884 
1885   // If N is a commutative binary node, try to eliminate it if the commuted
1886   // version is already present in the DAG.
1887   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
1888     SDValue N0 = N->getOperand(0);
1889     SDValue N1 = N->getOperand(1);
1890 
1891     // Constant operands are canonicalized to RHS.
1892     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
1893       SDValue Ops[] = {N1, N0};
1894       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
1895                                             N->getFlags());
1896       if (CSENode)
1897         return SDValue(CSENode, 0);
1898     }
1899   }
1900 
1901   return RV;
1902 }
1903 
1904 /// Given a node, return its input chain if it has one, otherwise return a null
1905 /// sd operand.
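/// By convention the chain, when present, is either the first or the last
/// operand, so those positions are checked before scanning the rest.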
1906 static SDValue getInputChainForNode(SDNode *N) {
1907   if (unsigned NumOps = N->getNumOperands()) {
1908     if (N->getOperand(0).getValueType() == MVT::Other)
1909       return N->getOperand(0);
1910     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1911       return N->getOperand(NumOps-1);
1912     for (unsigned i = 1; i < NumOps-1; ++i)
1913       if (N->getOperand(i).getValueType() == MVT::Other)
1914         return N->getOperand(i);
1915   }
1916   return SDValue();
1917 }
1918 
1919 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
1920   // If N has two operands, where one has an input chain equal to the other,
1921   // the 'other' chain is redundant.
1922   if (N->getNumOperands() == 2) {
1923     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
1924       return N->getOperand(0);
1925     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
1926       return N->getOperand(1);
1927   }
1928 
1929   // Don't simplify token factors if optnone.
1930   if (OptLevel == CodeGenOpt::None)
1931     return SDValue();
1932 
1933   // Don't simplify the token factor if the node itself has too many operands.
1934   if (N->getNumOperands() > TokenFactorInlineLimit)
1935     return SDValue();
1936 
1937   // If the sole user is a token factor, we should make sure we have a
1938   // chance to merge them together. This prevents TF chains from inhibiting
1939   // optimizations.
1940   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
1941     AddToWorklist(*(N->use_begin()));
1942 
1943   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
1944   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
1945   SmallPtrSet<SDNode*, 16> SeenOps;
1946   bool Changed = false;             // If we should replace this token factor.
1947 
1948   // Start out with this token factor.
1949   TFs.push_back(N);
1950 
1951   // Iterate through token factors.  The TFs list grows when new token
1952   // factors are encountered.
1953   for (unsigned i = 0; i < TFs.size(); ++i) {
1954     // Limit number of nodes to inline, to avoid quadratic compile times.
1955     // We have to add the outstanding Token Factors to Ops, otherwise we might
1956     // drop Ops from the resulting Token Factors.
1957     if (Ops.size() > TokenFactorInlineLimit) {
1958       for (unsigned j = i; j < TFs.size(); j++)
1959         Ops.emplace_back(TFs[j], 0);
1960       // Drop unprocessed Token Factors from TFs, so we do not add them to the
1961       // combiner worklist later.
1962       TFs.resize(i);
1963       break;
1964     }
1965 
1966     SDNode *TF = TFs[i];
1967     // Check each of the operands.
1968     for (const SDValue &Op : TF->op_values()) {
1969       switch (Op.getOpcode()) {
1970       case ISD::EntryToken:
1971         // Entry tokens don't need to be added to the list. They are
1972         // redundant.
1973         Changed = true;
1974         break;
1975 
1976       case ISD::TokenFactor:
1977         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
1978           // Queue up for processing.
1979           TFs.push_back(Op.getNode());
1980           Changed = true;
1981           break;
1982         }
1983         [[fallthrough]];
1984 
1985       default:
1986         // Only add if it isn't already in the list.
1987         if (SeenOps.insert(Op.getNode()).second)
1988           Ops.push_back(Op);
1989         else
1990           Changed = true;
1991         break;
1992       }
1993     }
1994   }
1995 
1996   // Re-visit inlined Token Factors, to clean them up in case they have been
1997   // removed. Skip the first Token Factor, as this is the current node.
1998   for (unsigned i = 1, e = TFs.size(); i < e; i++)
1999     AddToWorklist(TFs[i]);
2000 
2001   // Remove Nodes that are chained to another node in the list. Do so
2002   // by walking up chains breadth-first, stopping when we've seen
2003   // another operand. In general we must climb to the EntryNode, but we can
2004   // exit early if we find all remaining work is associated with just one
2005   // operand, as no further pruning is possible.
2006 
2007   // List of nodes to search through and original Ops from which they originate.
2008   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2009   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2010   SmallPtrSet<SDNode *, 16> SeenChains;
2011   bool DidPruneOps = false;
2012 
2013   unsigned NumLeftToConsider = 0;
2014   for (const SDValue &Op : Ops) {
2015     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
2016     OpWorkCount.push_back(1);
2017   }
2018 
2019   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2020     // If this is an Op, we can remove the op from the list. Re-mark any
2021     // search associated with it as from the current OpNumber.
2022     if (SeenOps.contains(Op)) {
2023       Changed = true;
2024       DidPruneOps = true;
2025       unsigned OrigOpNumber = 0;
2026       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2027         OrigOpNumber++;
2028       assert((OrigOpNumber != Ops.size()) &&
2029              "expected to find TokenFactor Operand");
2030       // Re-mark worklist from OrigOpNumber to OpNumber
2031       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2032         if (Worklist[i].second == OrigOpNumber) {
2033           Worklist[i].second = OpNumber;
2034         }
2035       }
2036       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2037       OpWorkCount[OrigOpNumber] = 0;
2038       NumLeftToConsider--;
2039     }
2040     // Add if it's a new chain
2041     if (SeenChains.insert(Op).second) {
2042       OpWorkCount[OpNumber]++;
2043       Worklist.push_back(std::make_pair(Op, OpNumber));
2044     }
2045   };
2046 
2047   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2048     // We need to consider at least 2 Ops to prune.
2049     if (NumLeftToConsider <= 1)
2050       break;
2051     auto CurNode = Worklist[i].first;
2052     auto CurOpNumber = Worklist[i].second;
2053     assert((OpWorkCount[CurOpNumber] > 0) &&
2054            "Node should not appear in worklist");
2055     switch (CurNode->getOpcode()) {
2056     case ISD::EntryToken:
2057       // Hitting EntryToken is the only way for the search to terminate
2058       // without hitting another operand's search.
2059       // Prevent us from marking this operand considered.
2061       NumLeftToConsider++;
2062       break;
2063     case ISD::TokenFactor:
2064       for (const SDValue &Op : CurNode->op_values())
2065         AddToWorklist(i, Op.getNode(), CurOpNumber);
2066       break;
2067     case ISD::LIFETIME_START:
2068     case ISD::LIFETIME_END:
2069     case ISD::CopyFromReg:
2070     case ISD::CopyToReg:
2071       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
2072       break;
2073     default:
2074       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
2075         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2076       break;
2077     }
2078     OpWorkCount[CurOpNumber]--;
2079     if (OpWorkCount[CurOpNumber] == 0)
2080       NumLeftToConsider--;
2081   }
2082 
2083   // If we've changed things around then replace token factor.
2084   if (Changed) {
2085     SDValue Result;
2086     if (Ops.empty()) {
2087       // The entry token is the only possible outcome.
2088       Result = DAG.getEntryNode();
2089     } else {
2090       if (DidPruneOps) {
2091         SmallVector<SDValue, 8> PrunedOps;
2093         for (const SDValue &Op : Ops) {
2094           if (SeenChains.count(Op.getNode()) == 0)
2095             PrunedOps.push_back(Op);
2096         }
2097         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2098       } else {
2099         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2100       }
2101     }
2102     return Result;
2103   }
2104   return SDValue();
2105 }
2106 
2107 /// MERGE_VALUES can always be eliminated.
2108 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2109   WorklistRemover DeadNodes(*this);
2110   // Replacing results may cause a different MERGE_VALUES to suddenly
2111   // be CSE'd with N, and carry its uses with it. Iterate until no
2112   // uses remain, to ensure that the node can be safely deleted.
2113   // First add the users of this node to the work list so that they
2114   // can be tried again once they have new operands.
2115   AddUsersToWorklist(N);
2116   do {
2117     // Do as a single replacement to avoid rewalking use lists.
2118     SmallVector<SDValue, 8> Ops;
2119     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2120       Ops.push_back(N->getOperand(i));
2121     DAG.ReplaceAllUsesWith(N, Ops.data());
2122   } while (!N->use_empty());
2123   deleteAndRecombine(N);
2124   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2125 }
2126 
2127 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2128 /// ConstantSDNode pointer else nullptr.
2129 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2130   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2131   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2132 }
2133 
2134 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2135 /// and that N may be folded in the load / store addressing mode.
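/// e.g. if N is (add BasePtr, 16) and the target supports [reg + imm]
/// addressing for the access type, the add can fold into the memory op.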
2136 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2137                                     const TargetLowering &TLI) {
2138   EVT VT;
2139   unsigned AS;
2140 
2141   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2142     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2143       return false;
2144     VT = LD->getMemoryVT();
2145     AS = LD->getAddressSpace();
2146   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2147     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2148       return false;
2149     VT = ST->getMemoryVT();
2150     AS = ST->getAddressSpace();
2151   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2152     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2153       return false;
2154     VT = LD->getMemoryVT();
2155     AS = LD->getAddressSpace();
2156   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2157     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2158       return false;
2159     VT = ST->getMemoryVT();
2160     AS = ST->getAddressSpace();
2161   } else {
2162     return false;
2163   }
2164 
2165   TargetLowering::AddrMode AM;
2166   if (N->getOpcode() == ISD::ADD) {
2167     AM.HasBaseReg = true;
2168     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2169     if (Offset)
2170       // [reg +/- imm]
2171       AM.BaseOffs = Offset->getSExtValue();
2172     else
2173       // [reg +/- reg]
2174       AM.Scale = 1;
2175   } else if (N->getOpcode() == ISD::SUB) {
2176     AM.HasBaseReg = true;
2177     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2178     if (Offset)
2179       // [reg +/- imm]
2180       AM.BaseOffs = -Offset->getSExtValue();
2181     else
2182       // [reg +/- reg]
2183       AM.Scale = 1;
2184   } else {
2185     return false;
2186   }
2187 
2188   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2189                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2190 }
2191 
2192 /// This inverts a canonicalization in IR that replaces a variable select arm
2193 /// with an identity constant. Codegen improves if we re-use the variable
2194 /// operand rather than load a constant. This can also be converted into a
2195 /// masked vector operation if the target supports it.
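/// Illustrative instance with add's identity constant 0:
///   add X, (vselect Cond, 0, Y) --> vselect Cond, X, (add X, Y)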
2196 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2197                                               bool ShouldCommuteOperands) {
2198   // Match a select as operand 1. The identity constant that we are looking for
2199   // is only valid as operand 1 of a non-commutative binop.
2200   SDValue N0 = N->getOperand(0);
2201   SDValue N1 = N->getOperand(1);
2202   if (ShouldCommuteOperands)
2203     std::swap(N0, N1);
2204 
2205   // TODO: Should this apply to scalar select too?
2206   if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2207     return SDValue();
2208 
2209   // We can't hoist div/rem because of immediate UB (not speculatable).
2210   unsigned Opcode = N->getOpcode();
2211   if (!DAG.isSafeToSpeculativelyExecute(Opcode))
2212     return SDValue();
2213 
2214   EVT VT = N->getValueType(0);
2215   SDValue Cond = N1.getOperand(0);
2216   SDValue TVal = N1.getOperand(1);
2217   SDValue FVal = N1.getOperand(2);
2218 
2219   // This transform increases uses of N0, so freeze it to be safe.
2220   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2221   unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2222   if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2223     SDValue F0 = DAG.getFreeze(N0);
2224     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2225     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2226   }
2227   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2228   if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2229     SDValue F0 = DAG.getFreeze(N0);
2230     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2231     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2232   }
2233 
2234   return SDValue();
2235 }
2236 
2237 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2238   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2239          "Unexpected binary operator");
2240 
2241   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2242   auto BinOpcode = BO->getOpcode();
2243   EVT VT = BO->getValueType(0);
2244   if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2245     if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2246       return Sel;
2247 
2248     if (TLI.isCommutativeBinOp(BO->getOpcode()))
2249       if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2250         return Sel;
2251   }
2252 
2253   // Don't do this unless the old select is going away. We want to eliminate the
2254   // binary operator, not replace a binop with a select.
2255   // TODO: Handle ISD::SELECT_CC.
2256   unsigned SelOpNo = 0;
2257   SDValue Sel = BO->getOperand(0);
2258   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2259     SelOpNo = 1;
2260     Sel = BO->getOperand(1);
2261   }
2262 
2263   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2264     return SDValue();
2265 
2266   SDValue CT = Sel.getOperand(1);
2267   if (!isConstantOrConstantVector(CT, true) &&
2268       !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2269     return SDValue();
2270 
2271   SDValue CF = Sel.getOperand(2);
2272   if (!isConstantOrConstantVector(CF, true) &&
2273       !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2274     return SDValue();
2275 
2276   // Bail out if any constants are opaque because we can't constant fold those.
2277   // The exception is "and" and "or" with either 0 or -1, in which case we
2278   // can propagate non-constant operands into the select. I.e.:
2279   // and (select Cond, 0, -1), X --> select Cond, 0, X
2280   // or X, (select Cond, -1, 0) --> select Cond, -1, X
2281   bool CanFoldNonConst =
2282       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2283       ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
2284        (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));
2285 
2286   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2287   if (!CanFoldNonConst &&
2288       !isConstantOrConstantVector(CBO, true) &&
2289       !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2290     return SDValue();
2291 
2292   SDLoc DL(Sel);
2293   SDValue NewCT, NewCF;
2294 
2295   if (CanFoldNonConst) {
2296     // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2297     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
2298         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
2299       NewCT = CT;
2300     else
2301       NewCT = CBO;
2302 
2303     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
2304         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
2305       NewCF = CF;
2306     else
2307       NewCF = CBO;
2308   } else {
2309     // We have a select-of-constants followed by a binary operator with a
2310     // constant. Eliminate the binop by pulling the constant math into the
2311     // select. Example:
2312     //   add (select Cond, CT, CF), CBO --> select Cond, CT+CBO, CF+CBO
2313     NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
2314                     : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
2315     if (!CanFoldNonConst && !NewCT.isUndef() &&
2316         !isConstantOrConstantVector(NewCT, true) &&
2317         !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
2318       return SDValue();
2319 
2320     NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
2321                     : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
2322     if (!CanFoldNonConst && !NewCF.isUndef() &&
2323         !isConstantOrConstantVector(NewCF, true) &&
2324         !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
2325       return SDValue();
2326   }
2327 
2328   SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
2329   SelectOp->setFlags(BO->getFlags());
2330   return SelectOp;
2331 }
2332 
2333 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2334   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2335          "Expecting add or sub");
2336 
2337   // Match a constant operand and a zext operand for the math instruction:
2338   // add Z, C
2339   // sub C, Z
2340   bool IsAdd = N->getOpcode() == ISD::ADD;
2341   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2342   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2343   auto *CN = dyn_cast<ConstantSDNode>(C);
2344   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2345     return SDValue();
2346 
2347   // Match the zext operand as a setcc of a boolean.
2348   if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2349       Z.getOperand(0).getValueType() != MVT::i1)
2350     return SDValue();
2351 
2352   // Match the compare as: setcc (X & 1), 0, eq.
2353   SDValue SetCC = Z.getOperand(0);
2354   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2355   if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2356       SetCC.getOperand(0).getOpcode() != ISD::AND ||
2357       !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2358     return SDValue();
2359 
2360   // We are adding/subtracting a constant and an inverted low bit. Turn that
2361   // into a subtract/add of the low bit with incremented/decremented constant:
2362   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2363   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2364   EVT VT = C.getValueType();
2365   SDLoc DL(N);
2366   SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2367   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2368                        DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2369   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2370 }
2371 
2372 /// Try to fold a 'not' shifted sign-bit with an add/sub with a constant
2373 /// operand into a shift and add with a different constant.
2374 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2375   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2376          "Expecting add or sub");
2377 
2378   // We need a constant operand for the add/sub, and the other operand is a
2379   // logical shift right: add (srl), C or sub C, (srl).
2380   bool IsAdd = N->getOpcode() == ISD::ADD;
2381   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2382   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2383   if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2384       ShiftOp.getOpcode() != ISD::SRL)
2385     return SDValue();
2386 
2387   // The shift must be of a 'not' value.
2388   SDValue Not = ShiftOp.getOperand(0);
2389   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2390     return SDValue();
2391 
2392   // The shift must be moving the sign bit to the least-significant-bit.
2393   EVT VT = ShiftOp.getValueType();
2394   SDValue ShAmt = ShiftOp.getOperand(1);
2395   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2396   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2397     return SDValue();
2398 
2399   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2400   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2401   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2402   SDLoc DL(N);
2403   if (SDValue NewC = DAG.FoldConstantArithmetic(
2404           IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2405           {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2406     SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2407                                    Not.getOperand(0), ShAmt);
2408     return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2409   }
2410 
2411   return SDValue();
2412 }
2413 
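/// Return true if V behaves like an ADD of its operands: an OR with no common
/// bits set between the operands, or an XOR with the minimum signed constant
/// (flipping the sign bit is the same as adding it, since no carry escapes).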
2414 static bool isADDLike(SDValue V, const SelectionDAG &DAG) {
2415   unsigned Opcode = V.getOpcode();
2416   if (Opcode == ISD::OR)
2417     return DAG.haveNoCommonBitsSet(V.getOperand(0), V.getOperand(1));
2418   if (Opcode == ISD::XOR)
2419     return isMinSignedConstant(V.getOperand(1));
2420   return false;
2421 }
2422 
2423 /// Try to fold a node that behaves like an ADD. Note that N isn't necessarily
2424 /// an ISD::ADD here; it could, for example, be an ISD::OR if we know that
2425 /// there are no common bits set in the operands.
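/// e.g. (or X, 4) with bit 2 of X known zero behaves exactly like (add X, 4),
/// so the ADD folds below apply to it as well (see isADDLike).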
2426 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2427   SDValue N0 = N->getOperand(0);
2428   SDValue N1 = N->getOperand(1);
2429   EVT VT = N0.getValueType();
2430   SDLoc DL(N);
2431 
2432   // fold (add x, undef) -> undef
2433   if (N0.isUndef())
2434     return N0;
2435   if (N1.isUndef())
2436     return N1;
2437 
2438   // fold (add c1, c2) -> c1+c2
2439   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
2440     return C;
2441 
2442   // canonicalize constant to RHS
2443   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2444       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2445     return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2446 
2447   // fold vector ops
2448   if (VT.isVector()) {
2449     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2450       return FoldedVOp;
2451 
2452     // fold (add x, 0) -> x, vector edition
2453     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2454       return N0;
2455   }
2456 
2457   // fold (add x, 0) -> x
2458   if (isNullConstant(N1))
2459     return N0;
2460 
2461   if (N0.getOpcode() == ISD::SUB) {
2462     SDValue N00 = N0.getOperand(0);
2463     SDValue N01 = N0.getOperand(1);
2464 
2465     // fold ((A-c1)+c2) -> (A+(c2-c1))
2466     if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
2467       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2468 
2469     // fold ((c1-A)+c2) -> (c1+c2)-A
2470     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
2471       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2472   }
2473 
2474   // add (sext i1 X), 1 -> zext (not i1 X)
2475   // We don't transform this pattern:
2476   //   add (zext i1 X), -1 -> sext (not i1 X)
2477   // because most (?) targets generate better code for the zext form.
2478   if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2479       isOneOrOneSplat(N1)) {
2480     SDValue X = N0.getOperand(0);
2481     if ((!LegalOperations ||
2482          (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2483           TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2484         X.getScalarValueSizeInBits() == 1) {
2485       SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2486       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2487     }
2488   }
2489 
2490   // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2491   // iff (or x, c0) is equivalent to (add x, c0).
2492   // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2493   // iff (xor x, c0) is equivalent to (add x, c0).
2494   if (isADDLike(N0, DAG)) {
2495     SDValue N01 = N0.getOperand(1);
2496     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
2497       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
2498   }
2499 
2500   if (SDValue NewSel = foldBinOpIntoSelect(N))
2501     return NewSel;
2502 
2503   // reassociate add
2504   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
2505     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2506       return RADD;
2507 
2508     // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2509     // equivalent to (add x, c).
2510     // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
2511     // equivalent to (add x, c).
2512     auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2513       if (isADDLike(N0, DAG) && N0.hasOneUse() &&
2514           isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2515         return DAG.getNode(ISD::ADD, DL, VT,
2516                            DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2517                            N0.getOperand(1));
2518       }
2519       return SDValue();
2520     };
2521     if (SDValue Add = ReassociateAddOr(N0, N1))
2522       return Add;
2523     if (SDValue Add = ReassociateAddOr(N1, N0))
2524       return Add;
2525   }
2526   // fold ((0-A) + B) -> B-A
2527   if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
2528     return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2529 
2530   // fold (A + (0-B)) -> A-B
2531   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
2532     return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
2533 
2534   // fold (A+(B-A)) -> B
2535   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
2536     return N1.getOperand(0);
2537 
2538   // fold ((B-A)+A) -> B
2539   if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
2540     return N0.getOperand(0);
2541 
2542   // fold ((A-B)+(C-A)) -> (C-B)
2543   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2544       N0.getOperand(0) == N1.getOperand(1))
2545     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2546                        N0.getOperand(1));
2547 
2548   // fold ((A-B)+(B-C)) -> (A-C)
2549   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2550       N0.getOperand(1) == N1.getOperand(0))
2551     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
2552                        N1.getOperand(1));
2553 
2554   // fold (A+(B-(A+C))) to (B-C)
2555   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2556       N0 == N1.getOperand(1).getOperand(0))
2557     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2558                        N1.getOperand(1).getOperand(1));
2559 
2560   // fold (A+(B-(C+A))) to (B-C)
2561   if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2562       N0 == N1.getOperand(1).getOperand(1))
2563     return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2564                        N1.getOperand(1).getOperand(0));
2565 
2566   // fold (A+((B-A)+or-C)) to (B+or-C)
2567   if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2568       N1.getOperand(0).getOpcode() == ISD::SUB &&
2569       N0 == N1.getOperand(0).getOperand(1))
2570     return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
2571                        N1.getOperand(1));
2572 
2573   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2574   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2575       N0->hasOneUse() && N1->hasOneUse()) {
2576     SDValue N00 = N0.getOperand(0);
2577     SDValue N01 = N0.getOperand(1);
2578     SDValue N10 = N1.getOperand(0);
2579     SDValue N11 = N1.getOperand(1);
2580 
2581     if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
2582       return DAG.getNode(ISD::SUB, DL, VT,
2583                          DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
2584                          DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
2585   }
2586 
2587   // fold (add (umax X, C), -C) --> (usubsat X, C)
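  // e.g. for i8: (add (umax X, 16), -16) yields X - 16 when X >= 16 and 0
  // otherwise, which is exactly (usubsat X, 16).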
2588   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2589     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2590       return (!Max && !Op) ||
2591              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2592     };
2593     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2594                                   /*AllowUndefs*/ true))
2595       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2596                          N0.getOperand(1));
2597   }
2598 
2599   if (SimplifyDemandedBits(SDValue(N, 0)))
2600     return SDValue(N, 0);
2601 
2602   if (isOneOrOneSplat(N1)) {
2603     // fold (add (xor a, -1), 1) -> (sub 0, a)
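    // (two's-complement identity: ~a + 1 == -a == 0 - a)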
2604     if (isBitwiseNot(N0))
2605       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2606                          N0.getOperand(0));
2607 
2608     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2609     if (N0.getOpcode() == ISD::ADD) {
2610       SDValue A, Xor;
2611 
2612       if (isBitwiseNot(N0.getOperand(0))) {
2613         A = N0.getOperand(1);
2614         Xor = N0.getOperand(0);
2615       } else if (isBitwiseNot(N0.getOperand(1))) {
2616         A = N0.getOperand(0);
2617         Xor = N0.getOperand(1);
2618       }
2619 
2620       if (Xor)
2621         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2622     }
2623 
2624     // Look for:
2625     //   add (add x, y), 1
2626     // And if the target does not like this form then turn into:
2627     //   sub y, (xor x, -1)
2628     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2629         N0.hasOneUse()) {
2630       SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2631                                 DAG.getAllOnesConstant(DL, VT));
2632       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2633     }
2634   }
2635 
2636   // (x - y) + -1  ->  add (xor y, -1), x
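  // (since ~y == -y - 1, (xor y, -1) + x == x - y - 1)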
2637   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2638       isAllOnesOrAllOnesSplat(N1)) {
2639     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
2640     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
2641   }
2642 
2643   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2644     return Combined;
2645 
2646   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2647     return Combined;
2648 
2649   return SDValue();
2650 }
2651 
2652 SDValue DAGCombiner::visitADD(SDNode *N) {
2653   SDValue N0 = N->getOperand(0);
2654   SDValue N1 = N->getOperand(1);
2655   EVT VT = N0.getValueType();
2656   SDLoc DL(N);
2657 
2658   if (SDValue Combined = visitADDLike(N))
2659     return Combined;
2660 
2661   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2662     return V;
2663 
2664   if (SDValue V = foldAddSubOfSignBit(N, DAG))
2665     return V;
2666 
2667   // fold (a+b) -> (a|b) iff a and b share no bits.
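  // With no bit position set in both operands there can be no carries, so the
  // add is a bitwise OR, e.g. (add (shl x, 8), (and y, 255)) on i16 or wider.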
2668   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2669       DAG.haveNoCommonBitsSet(N0, N1))
2670     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
2671 
2672   // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2673   if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2674     const APInt &C0 = N0->getConstantOperandAPInt(0);
2675     const APInt &C1 = N1->getConstantOperandAPInt(0);
2676     return DAG.getVScale(DL, VT, C0 + C1);
2677   }
2678 
2679   // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2680   if (N0.getOpcode() == ISD::ADD &&
2681       N0.getOperand(1).getOpcode() == ISD::VSCALE &&
2682       N1.getOpcode() == ISD::VSCALE) {
2683     const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2684     const APInt &VS1 = N1->getConstantOperandAPInt(0);
2685     SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
2686     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
2687   }
2688 
2689   // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2690   if (N0.getOpcode() == ISD::STEP_VECTOR &&
2691       N1.getOpcode() == ISD::STEP_VECTOR) {
2692     const APInt &C0 = N0->getConstantOperandAPInt(0);
2693     const APInt &C1 = N1->getConstantOperandAPInt(0);
2694     APInt NewStep = C0 + C1;
2695     return DAG.getStepVector(DL, VT, NewStep);
2696   }
2697 
2698   // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
2699   if (N0.getOpcode() == ISD::ADD &&
2700       N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
2701       N1.getOpcode() == ISD::STEP_VECTOR) {
2702     const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2703     const APInt &SV1 = N1->getConstantOperandAPInt(0);
2704     APInt NewStep = SV0 + SV1;
2705     SDValue SV = DAG.getStepVector(DL, VT, NewStep);
2706     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
2707   }
2708 
2709   return SDValue();
2710 }
2711 
2712 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2713   unsigned Opcode = N->getOpcode();
2714   SDValue N0 = N->getOperand(0);
2715   SDValue N1 = N->getOperand(1);
2716   EVT VT = N0.getValueType();
2717   SDLoc DL(N);
2718 
2719   // fold (add_sat x, undef) -> -1
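  // (an undef operand lets the result be chosen freely; all-ones is used here)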
2720   if (N0.isUndef() || N1.isUndef())
2721     return DAG.getAllOnesConstant(DL, VT);
2722 
2723   // fold (add_sat c1, c2) -> c3
2724   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
2725     return C;
2726 
2727   // canonicalize constant to RHS
2728   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2729       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2730     return DAG.getNode(Opcode, DL, VT, N1, N0);
2731 
2732   // fold vector ops
2733   if (VT.isVector()) {
2734     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2735       return FoldedVOp;
2736 
2737     // fold (add_sat x, 0) -> x, vector edition
2738     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2739       return N0;
2740   }
2741 
2742   // fold (add_sat x, 0) -> x
2743   if (isNullConstant(N1))
2744     return N0;
2745 
2746   // If it cannot overflow, transform into an add.
2747   if (Opcode == ISD::UADDSAT)
2748     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2749       return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2750 
2751   return SDValue();
2752 }
2753 
2754 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2755   bool Masked = false;
2756 
2757   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
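  // e.g. the carry may have been legalized into (and (trunc Carry), 1) or
  // (zero_extend Carry); unwrap such nodes to reach the flag-producing value.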
2758   while (true) {
2759     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2760       V = V.getOperand(0);
2761       continue;
2762     }
2763 
2764     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2765       Masked = true;
2766       V = V.getOperand(0);
2767       continue;
2768     }
2769 
2770     break;
2771   }
2772 
2773   // If this is not a carry, return.
2774   if (V.getResNo() != 1)
2775     return SDValue();
2776 
2777   if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2778       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2779     return SDValue();
2780 
2781   EVT VT = V->getValueType(0);
2782   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2783     return SDValue();
2784 
2785   // If the result is masked, then no matter what kind of bool it is we can
2786   // return. If it isn't, then we need to make sure the boolean is known to
2787   // be either 0 or 1 and not some other value.
2788   if (Masked ||
2789       TLI.getBooleanContents(V.getValueType()) ==
2790           TargetLoweringBase::ZeroOrOneBooleanContent)
2791     return V;
2792 
2793   return SDValue();
2794 }
2795 
2796 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2797 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2798 /// the opcode and bypass the mask operation.
2799 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2800                                  SelectionDAG &DAG, const SDLoc &DL) {
2801   if (N1.getOpcode() == ISD::ZERO_EXTEND)
2802     N1 = N1.getOperand(0);
2803 
2804   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2805     return SDValue();
2806 
2807   EVT VT = N0.getValueType();
2808   SDValue N10 = N1.getOperand(0);
2809   if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
2810     N10 = N10.getOperand(0);
2811 
2812   if (N10.getValueType() != VT)
2813     return SDValue();
2814 
2815   if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
2816     return SDValue();
2817 
2818   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2819   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
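  // X is known to be all sign bits (0 or -1), so (and X, 1) == -X and the
  // mask can be bypassed by flipping the opcode.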
2820   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
2821 }
2822 
2823 /// Helper for doing combines based on N0 and N1 being added to each other.
2824 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
2825                                              SDNode *LocReference) {
2826   EVT VT = N0.getValueType();
2827   SDLoc DL(LocReference);
2828 
2829   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
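  // (shl distributes over negation modulo 2^bw: (0 - y) << n == 0 - (y << n))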
2830   if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
2831       isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
2832     return DAG.getNode(ISD::SUB, DL, VT, N0,
2833                        DAG.getNode(ISD::SHL, DL, VT,
2834                                    N1.getOperand(0).getOperand(1),
2835                                    N1.getOperand(1)));
2836 
2837   if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
2838     return V;
2839 
2840   // Look for:
2841   //   add (add x, 1), y
2842   // And if the target does not like this form then turn into:
2843   //   sub y, (xor x, -1)
2844   if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2845       N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) {
2846     SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2847                               DAG.getAllOnesConstant(DL, VT));
2848     return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
2849   }
2850 
2851   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
2852     // Hoist one-use subtraction by non-opaque constant:
2853     //   (x - C) + y  ->  (x + y) - C
2854     // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
2855     if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
2856       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
2857       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2858     }
2859     // Hoist one-use subtraction from non-opaque constant:
2860     //   (C - x) + y  ->  (y - x) + C
2861     if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
2862       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2863       return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
2864     }
2865   }
2866 
2867   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
2868   // rather than 'add 0/-1' (the zext should get folded).
2869   // add (sext i1 Y), X --> sub X, (zext i1 Y)
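  // (for i1 Y, sext(Y) is 0 or -1, i.e. -zext(Y), so the add becomes a sub)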
2870   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
2871       N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
2872       TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
2873     SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
2874     return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
2875   }
2876 
2877   // add X, (sextinreg Y i1) -> sub X, (and Y 1)
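  // (sextinreg to i1 yields 0 or -1 from bit 0 of Y, i.e. -(Y & 1))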
2878   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
2879     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
2880     if (TN->getVT() == MVT::i1) {
2881       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
2882                                  DAG.getConstant(1, DL, VT));
2883       return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
2884     }
2885   }
2886 
2887   // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2888   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
2889       N1.getResNo() == 0)
2890     return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
2891                        N0, N1.getOperand(0), N1.getOperand(2));
2892 
2893   // (add X, Carry) -> (addcarry X, 0, Carry)
2894   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2895     if (SDValue Carry = getAsCarry(TLI, N1))
2896       return DAG.getNode(ISD::ADDCARRY, DL,
2897                          DAG.getVTList(VT, Carry.getValueType()), N0,
2898                          DAG.getConstant(0, DL, VT), Carry);
2899 
2900   return SDValue();
2901 }
2902 
2903 SDValue DAGCombiner::visitADDC(SDNode *N) {
2904   SDValue N0 = N->getOperand(0);
2905   SDValue N1 = N->getOperand(1);
2906   EVT VT = N0.getValueType();
2907   SDLoc DL(N);
2908 
2909   // If the flag result is dead, turn this into an ADD.
2910   if (!N->hasAnyUseOfValue(1))
2911     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2912                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2913 
2914   // canonicalize constant to RHS.
2915   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2916   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2917   if (N0C && !N1C)
2918     return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
2919 
2920   // fold (addc x, 0) -> x + no carry out
2921   if (isNullConstant(N1))
2922     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
2923                                         DL, MVT::Glue));
2924 
2925   // If it cannot overflow, transform into an add.
2926   if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2927     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2928                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2929 
2930   return SDValue();
2931 }
2932 
2933 /**
2934  * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
2935  * then the flip also occurs if computing the inverse is the same cost.
2936  * This function returns an empty SDValue in case it cannot flip the boolean
2937  * without increasing the cost of the computation. If you want to flip a boolean
2938  * no matter what, use DAG.getLogicalNOT.
2939  */
2940 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2941                                   const TargetLowering &TLI,
2942                                   bool Force) {
2943   if (Force && isa<ConstantSDNode>(V))
2944     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2945 
2946   if (V.getOpcode() != ISD::XOR)
2947     return SDValue();
2948 
2949   ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2950   if (!Const)
2951     return SDValue();
2952 
2953   EVT VT = V.getValueType();
2954 
2955   bool IsFlip = false;
2956   switch(TLI.getBooleanContents(VT)) {
2957     case TargetLowering::ZeroOrOneBooleanContent:
2958       IsFlip = Const->isOne();
2959       break;
2960     case TargetLowering::ZeroOrNegativeOneBooleanContent:
2961       IsFlip = Const->isAllOnes();
2962       break;
2963     case TargetLowering::UndefinedBooleanContent:
2964       IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2965       break;
2966   }
2967 
2968   if (IsFlip)
2969     return V.getOperand(0);
2970   if (Force)
2971     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2972   return SDValue();
2973 }
2974 
2975 SDValue DAGCombiner::visitADDO(SDNode *N) {
2976   SDValue N0 = N->getOperand(0);
2977   SDValue N1 = N->getOperand(1);
2978   EVT VT = N0.getValueType();
2979   bool IsSigned = (ISD::SADDO == N->getOpcode());
2980 
2981   EVT CarryVT = N->getValueType(1);
2982   SDLoc DL(N);
2983 
2984   // If the flag result is dead, turn this into an ADD.
2985   if (!N->hasAnyUseOfValue(1))
2986     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2987                      DAG.getUNDEF(CarryVT));
2988 
2989   // canonicalize constant to RHS.
2990   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2991       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2992     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
2993 
2994   // fold (addo x, 0) -> x + no carry out
2995   if (isNullOrNullSplat(N1))
2996     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
2997 
2998   if (!IsSigned) {
2999     // If it cannot overflow, transform into an add.
3000     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
3001       return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3002                        DAG.getConstant(0, DL, CarryVT));
3003 
3004     // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
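    // ~a + 1 == 0 - a; the uaddo carries exactly when a == 0 while the usubo
    // borrows exactly when a != 0, hence the flipped carry.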
3005     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
3006       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
3007                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3008       return CombineTo(
3009           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3010     }
3011 
3012     if (SDValue Combined = visitUADDOLike(N0, N1, N))
3013       return Combined;
3014 
3015     if (SDValue Combined = visitUADDOLike(N1, N0, N))
3016       return Combined;
3017   }
3018 
3019   return SDValue();
3020 }
3021 
3022 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3023   EVT VT = N0.getValueType();
3024   if (VT.isVector())
3025     return SDValue();
3026 
3027   // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
3028   // If Y + 1 cannot overflow.
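  // (Carry is at most 1, so Y + Carry cannot wrap either, and the carry-out
  // of the merged addcarry matches that of the original uaddo.)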
3029   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
3030     SDValue Y = N1.getOperand(0);
3031     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3032     if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
3033       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
3034                          N1.getOperand(2));
3035   }
3036 
3037   // (uaddo X, Carry) -> (addcarry X, 0, Carry)
3038   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
3039     if (SDValue Carry = getAsCarry(TLI, N1))
3040       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
3041                          DAG.getConstant(0, SDLoc(N), VT), Carry);
3042 
3043   return SDValue();
3044 }
3045 
3046 SDValue DAGCombiner::visitADDE(SDNode *N) {
3047   SDValue N0 = N->getOperand(0);
3048   SDValue N1 = N->getOperand(1);
3049   SDValue CarryIn = N->getOperand(2);
3050 
3051   // canonicalize constant to RHS
3052   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3053   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3054   if (N0C && !N1C)
3055     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3056                        N1, N0, CarryIn);
3057 
3058   // fold (adde x, y, false) -> (addc x, y)
3059   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3060     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3061 
3062   return SDValue();
3063 }
3064 
3065 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
3066   SDValue N0 = N->getOperand(0);
3067   SDValue N1 = N->getOperand(1);
3068   SDValue CarryIn = N->getOperand(2);
3069   SDLoc DL(N);
3070 
3071   // canonicalize constant to RHS
3072   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3073   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3074   if (N0C && !N1C)
3075     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
3076 
3077   // fold (addcarry x, y, false) -> (uaddo x, y)
3078   if (isNullConstant(CarryIn)) {
3079     if (!LegalOperations ||
3080         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3081       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3082   }
3083 
3084   // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3085   if (isNullConstant(N0) && isNullConstant(N1)) {
3086     EVT VT = N0.getValueType();
3087     EVT CarryVT = CarryIn.getValueType();
3088     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3089     AddToWorklist(CarryExt.getNode());
3090     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3091                                     DAG.getConstant(1, DL, VT)),
3092                      DAG.getConstant(0, DL, CarryVT));
3093   }
3094 
3095   if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
3096     return Combined;
3097 
3098   if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
3099     return Combined;
3100 
3101   // We want to avoid useless duplication.
3102   // TODO: This is done automatically for binary operations. As ADDCARRY is
3103   // not a binary operation, it is not really possible to leverage this
3104   // existing mechanism for it. However, if more operations require the same
3105   // deduplication logic, it may be worth generalizing.
3106   SDValue Ops[] = {N1, N0, CarryIn};
3107   SDNode *CSENode =
3108       DAG.getNodeIfExists(ISD::ADDCARRY, N->getVTList(), Ops, N->getFlags());
3109   if (CSENode)
3110     return SDValue(CSENode, 0);
3111 
3112   return SDValue();
3113 }
3114 
3115 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3116   SDValue N0 = N->getOperand(0);
3117   SDValue N1 = N->getOperand(1);
3118   SDValue CarryIn = N->getOperand(2);
3119   SDLoc DL(N);
3120 
3121   // canonicalize constant to RHS
3122   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3123   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3124   if (N0C && !N1C)
3125     return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3126 
3127   // fold (saddo_carry x, y, false) -> (saddo x, y)
3128   if (isNullConstant(CarryIn)) {
3129     if (!LegalOperations ||
3130         TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3131       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3132   }
3133 
3134   return SDValue();
3135 }
3136 
3137 /**
3138  * If we are facing some sort of diamond carry propagation pattern, try to
3139  * break it up to generate something like:
3140  *   (addcarry X, 0, (addcarry A, B, Z):Carry)
3141  *
3142  * The end result is usually an increase in the number of operations required,
3143  * but with the carry linearized, other transforms can kick in and optimize the DAG.
3144  *
3145  * Patterns typically look something like
3146  *            (uaddo A, B)
3147  *             /       \
3148  *          Carry      Sum
3149  *            |          \
3150  *            | (addcarry *, 0, Z)
3151  *            |       /
3152  *             \   Carry
3153  *              |   /
3154  * (addcarry X, *, *)
3155  *
3156  * But numerous variations exist. Our goal is to identify A, B, X and Z and
3157  * produce a combine with a single path for carry propagation.
3158  */
3159 static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
3160                                       SDValue X, SDValue Carry0, SDValue Carry1,
3161                                       SDNode *N) {
3162   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3163     return SDValue();
3164   if (Carry1.getOpcode() != ISD::UADDO)
3165     return SDValue();
3166 
3167   SDValue Z;
3168 
3169   /**
3170    * First look for a suitable Z. It will present itself in the form of
3171    * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3172    */
3173   if (Carry0.getOpcode() == ISD::ADDCARRY &&
3174       isNullConstant(Carry0.getOperand(1))) {
3175     Z = Carry0.getOperand(2);
3176   } else if (Carry0.getOpcode() == ISD::UADDO &&
3177              isOneConstant(Carry0.getOperand(1))) {
3178     EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
3179     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
3180   } else {
3181     // We couldn't find a suitable Z.
3182     return SDValue();
3183   }
3184 
3185 
3186   auto cancelDiamond = [&](SDValue A, SDValue B) {
3187     SDLoc DL(N);
3188     SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
3189     Combiner.AddToWorklist(NewY.getNode());
3190     return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
3191                        DAG.getConstant(0, DL, X.getValueType()),
3192                        NewY.getValue(1));
3193   };
3194 
3195   /**
3196    *      (uaddo A, B)
3197    *           |
3198    *          Sum
3199    *           |
3200    * (addcarry *, 0, Z)
3201    */
3202   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3203     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3204   }
3205 
3206   /**
3207    * (addcarry A, 0, Z)
3208    *         |
3209    *        Sum
3210    *         |
3211    *  (uaddo *, B)
3212    */
3213   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3214     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3215   }
3216 
3217   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3218     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3219   }
3220 
3221   return SDValue();
3222 }
3223 
3224 // If we are facing some sort of diamond carry/borrow in/out pattern, try to
3225 // match patterns like:
3226 //
3227 //          (uaddo A, B)            CarryIn
3228 //            |  \                     |
3229 //            |   \                    |
3230 //    PartialSum   PartialCarryOutX   /
3231 //            |        |             /
3232 //            |    ____|____________/
3233 //            |   /    |
3234 //     (uaddo *, *)    \________
3235 //       |  \                   \
3236 //       |   \                   |
3237 //       |    PartialCarryOutY   |
3238 //       |        \              |
3239 //       |         \            /
3240 //   AddCarrySum    |    ______/
3241 //                  |   /
3242 //   CarryOut = (or *, *)
3243 //
3244 // And generate ADDCARRY (or SUBCARRY) with two result values:
3245 //
3246 //    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
3247 //
3248 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
3249 // a single path for carry/borrow out propagation:
3250 static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3251                                    SDValue N0, SDValue N1, SDNode *N) {
3252   SDValue Carry0 = getAsCarry(TLI, N0);
3253   if (!Carry0)
3254     return SDValue();
3255   SDValue Carry1 = getAsCarry(TLI, N1);
3256   if (!Carry1)
3257     return SDValue();
3258 
3259   unsigned Opcode = Carry0.getOpcode();
3260   if (Opcode != Carry1.getOpcode())
3261     return SDValue();
3262   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3263     return SDValue();
3264 
3265   // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3266   // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3267   if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
3268     std::swap(Carry0, Carry1);
3269 
3270   // Check if nodes are connected in expected way.
3271   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3272       Carry1.getOperand(1) != Carry0.getValue(0))
3273     return SDValue();
3274 
3275   // The carry-in value must be on the right-hand side for subtraction.
3276   unsigned CarryInOperandNum =
3277       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3278   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3279     return SDValue();
3280   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3281 
3282   unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
3283   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3284     return SDValue();
3285 
3286   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3287   // TODO: make getAsCarry() aware of how partial carries are merged.
3288   if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
3289     return SDValue();
3290   CarryIn = CarryIn.getOperand(0);
3291   if (CarryIn.getValueType() != MVT::i1)
3292     return SDValue();
3293 
3294   SDLoc DL(N);
3295   SDValue Merged =
3296       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3297                   Carry0.getOperand(1), CarryIn);
3298 
3299   // Note that because we have proven that the result of the UADDO/USUBO of
3300   // A and B feeds into the UADDO/USUBO that consumes the carry/borrow in,
3301   // if the first UADDO/USUBO overflows, the second UADDO/USUBO cannot.
3302   // For example, consider 8-bit numbers where 0xFF is the
3303   // maximum value.
3304   //
3305   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3306   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3307   //
3308   // This is important because it means that OR and XOR can be used to merge
3309   // carry flags; and that AND can return a constant zero.
3310   //
3311   // TODO: match other operations that can merge flags (ADD, etc)
3312   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3313   if (N->getOpcode() == ISD::AND)
3314     return DAG.getConstant(0, DL, MVT::i1);
3315   return Merged.getValue(1);
3316 }
3317 
3318 SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
3319                                        SDNode *N) {
3320   // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
3321   if (isBitwiseNot(N0))
3322     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3323       SDLoc DL(N);
3324       SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
3325                                 N0.getOperand(0), NotC);
3326       return CombineTo(
3327           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3328     }
3329 
3330   // Iff the flag result is dead:
3331   // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
3332   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3333   // or the dependency between the instructions.
3334   if ((N0.getOpcode() == ISD::ADD ||
3335        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3336         N0.getValue(1) != CarryIn)) &&
3337       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3338     return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
3339                        N0.getOperand(0), N0.getOperand(1), CarryIn);
3340 
3341   /**
3342    * When one of the addcarry arguments is itself a carry, we may be facing
3343    * a diamond carry propagation. In that case we try to transform the DAG
3344    * to ensure linear carry propagation if that is possible.
3345    */
3346   if (auto Y = getAsCarry(TLI, N1)) {
3347     // Because both are carries, Y and Z can be swapped.
3348     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3349       return R;
3350     if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3351       return R;
3352   }
3353 
3354   return SDValue();
3355 }
3356 
3357 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3358 // clamp/truncation if necessary.
3359 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3360                                    SDValue RHS, SelectionDAG &DAG,
3361                                    const SDLoc &DL) {
3362   assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3363          "Illegal truncation");
3364 
3365   if (DstVT == SrcVT)
3366     return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3367 
3368   // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3369   // clamping RHS.
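  // e.g. SrcVT=i32, DstVT=i16 with LHS known < 0x10000: umin-clamp RHS to
  // 0xFFFF first; any larger RHS saturates the subtraction to 0 anyway, so
  // the narrower usubsat on the truncated operands computes the same value.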
3370   APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3371                                           DstVT.getScalarSizeInBits());
3372   if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3373     return SDValue();
3374 
3375   SDValue SatLimit =
3376       DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3377                                            DstVT.getScalarSizeInBits()),
3378                       DL, SrcVT);
3379   RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3380   RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3381   LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3382   return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3383 }
3384 
3385 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3386 // usubsat(a,b), optionally as a truncated type.
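// umax(a,b) - b is a - b when a > b and 0 otherwise, which is usubsat(a,b);
// a - umin(a,b) computes the same value by symmetry.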
3387 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
3388   if (N->getOpcode() != ISD::SUB ||
3389       !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3390     return SDValue();
3391 
3392   EVT SubVT = N->getValueType(0);
3393   SDValue Op0 = N->getOperand(0);
3394   SDValue Op1 = N->getOperand(1);
3395 
3396   // Try to find umax(a,b) - b or a - umin(a,b) patterns
3397   // that may be converted to usubsat(a,b).
3398   if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3399     SDValue MaxLHS = Op0.getOperand(0);
3400     SDValue MaxRHS = Op0.getOperand(1);
3401     if (MaxLHS == Op1)
3402       return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
3403     if (MaxRHS == Op1)
3404       return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
3405   }
3406 
3407   if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3408     SDValue MinLHS = Op1.getOperand(0);
3409     SDValue MinRHS = Op1.getOperand(1);
3410     if (MinLHS == Op0)
3411       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
3412     if (MinRHS == Op0)
3413       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
3414   }
3415 
3416   // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3417   if (Op1.getOpcode() == ISD::TRUNCATE &&
3418       Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3419       Op1.getOperand(0).hasOneUse()) {
3420     SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3421     SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3422     if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3423       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3424                                  DAG, SDLoc(N));
3425     if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3426       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3427                                  DAG, SDLoc(N));
3428   }
3429 
3430   return SDValue();
3431 }
3432 
3433 // Since it may not be valid to emit a fold to zero for vector initializers,
3434 // check if we can before folding.
3435 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3436                              SelectionDAG &DAG, bool LegalOperations) {
3437   if (!VT.isVector())
3438     return DAG.getConstant(0, DL, VT);
3439   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3440     return DAG.getConstant(0, DL, VT);
3441   return SDValue();
3442 }
3443 
3444 SDValue DAGCombiner::visitSUB(SDNode *N) {
3445   SDValue N0 = N->getOperand(0);
3446   SDValue N1 = N->getOperand(1);
3447   EVT VT = N0.getValueType();
3448   SDLoc DL(N);
3449 
3450   auto PeekThroughFreeze = [](SDValue N) {
3451     if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3452       return N->getOperand(0);
3453     return N;
3454   };
3455 
3456   // fold (sub x, x) -> 0
3457   // FIXME: Refactor this and xor and other similar operations together.
3458   if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3459     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3460 
3461   // fold (sub c1, c2) -> c3
3462   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3463     return C;
3464 
3465   // fold vector ops
3466   if (VT.isVector()) {
3467     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3468       return FoldedVOp;
3469 
3470     // fold (sub x, 0) -> x, vector edition
3471     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3472       return N0;
3473   }
3474 
3475   if (SDValue NewSel = foldBinOpIntoSelect(N))
3476     return NewSel;
3477 
3478   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3479 
3480   // fold (sub x, c) -> (add x, -c)
3481   if (N1C) {
3482     return DAG.getNode(ISD::ADD, DL, VT, N0,
3483                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3484   }
3485 
3486   if (isNullOrNullSplat(N0)) {
3487     unsigned BitWidth = VT.getScalarSizeInBits();
3488     // Right-shifting everything out but the sign bit followed by negation is
3489     // the same as flipping arithmetic/logical shift type without the negation:
3490     // -(X >>u 31) -> (X >>s 31)
3491     // -(X >>s 31) -> (X >>u 31)
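    // (X >>u 31 is 0 or 1, so its negation is 0 or -1, exactly X >>s 31;
    // negating X >>s 31 gives back 0 or 1, exactly X >>u 31)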
3492     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3493       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3494       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3495         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3496         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3497           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3498       }
3499     }
3500 
3501     // 0 - X --> 0 if the sub is NUW.
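    // (no unsigned wrap means 0 - X cannot borrow, which forces X == 0)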
3502     if (N->getFlags().hasNoUnsignedWrap())
3503       return N0;
3504 
3505     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3506       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3507       // N1 must be 0 because negating the minimum signed value is undefined.
3508       if (N->getFlags().hasNoSignedWrap())
3509         return N0;
3510 
3511       // 0 - X --> X if X is 0 or the minimum signed value.
3512       return N1;
3513     }
3514 
3515     // Convert 0 - abs(x).
3516     if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3517         !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
3518       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
3519         return Result;
3520 
3521     // Fold neg(splat(neg(x))) -> splat(x)
3522     if (VT.isVector()) {
3523       SDValue N1S = DAG.getSplatValue(N1, true);
3524       if (N1S && N1S.getOpcode() == ISD::SUB &&
3525           isNullConstant(N1S.getOperand(0)))
3526         return DAG.getSplat(VT, DL, N1S.getOperand(1));
3527     }
3528   }
3529 
3530   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3531   if (isAllOnesOrAllOnesSplat(N0))
3532     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3533 
3534   // fold (A - (0-B)) -> A+B
3535   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3536     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3537 
3538   // fold A-(A-B) -> B
3539   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3540     return N1.getOperand(1);
3541 
3542   // fold (A+B)-A -> B
3543   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3544     return N0.getOperand(1);
3545 
3546   // fold (A+B)-B -> A
3547   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3548     return N0.getOperand(0);
3549 
3550   // fold (A+C1)-C2 -> A+(C1-C2)
3551   if (N0.getOpcode() == ISD::ADD) {
3552     SDValue N01 = N0.getOperand(1);
3553     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
3554       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3555   }
3556 
3557   // fold C2-(A+C1) -> (C2-C1)-A
3558   if (N1.getOpcode() == ISD::ADD) {
3559     SDValue N11 = N1.getOperand(1);
3560     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
3561       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3562   }
3563 
3564   // fold (A-C1)-C2 -> A-(C1+C2)
3565   if (N0.getOpcode() == ISD::SUB) {
3566     SDValue N01 = N0.getOperand(1);
3567     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
3568       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3569   }
3570 
3571   // fold (c1-A)-c2 -> (c1-c2)-A
3572   if (N0.getOpcode() == ISD::SUB) {
3573     SDValue N00 = N0.getOperand(0);
3574     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
3575       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3576   }
3577 
3578   // fold ((A+(B+or-C))-B) -> A+or-C
3579   if (N0.getOpcode() == ISD::ADD &&
3580       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3581        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3582       N0.getOperand(1).getOperand(0) == N1)
3583     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3584                        N0.getOperand(1).getOperand(1));
3585 
3586   // fold ((A+(C+B))-B) -> A+C
3587   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3588       N0.getOperand(1).getOperand(1) == N1)
3589     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3590                        N0.getOperand(1).getOperand(0));
3591 
3592   // fold ((A-(B-C))-C) -> A-B
3593   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3594       N0.getOperand(1).getOperand(1) == N1)
3595     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3596                        N0.getOperand(1).getOperand(0));
3597 
3598   // fold (A-(B-C)) -> A+(C-B)
3599   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3600     return DAG.getNode(ISD::ADD, DL, VT, N0,
3601                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3602                                    N1.getOperand(0)));
3603 
3604   // A - (A & B)  ->  A & (~B)
3605   if (N1.getOpcode() == ISD::AND) {
3606     SDValue A = N1.getOperand(0);
3607     SDValue B = N1.getOperand(1);
3608     if (A != N0)
3609       std::swap(A, B);
3610     if (A == N0 &&
3611         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3612       SDValue InvB =
3613           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3614       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3615     }
3616   }
3617 
3618   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3619   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3620     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3621         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3622       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3623                                 N1.getOperand(0).getOperand(1),
3624                                 N1.getOperand(1));
3625       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3626     }
3627     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3628         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3629       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3630                                 N1.getOperand(0),
3631                                 N1.getOperand(1).getOperand(1));
3632       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3633     }
3634   }
3635 
3636   // If either operand of a sub is undef, the result is undef
3637   if (N0.isUndef())
3638     return N0;
3639   if (N1.isUndef())
3640     return N1;
3641 
3642   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3643     return V;
3644 
3645   if (SDValue V = foldAddSubOfSignBit(N, DAG))
3646     return V;
3647 
3648   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3649     return V;
3650 
3651   if (SDValue V = foldSubToUSubSat(VT, N))
3652     return V;
3653 
3654   // (x - y) - 1  ->  add (xor y, -1), x
3655   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isOneOrOneSplat(N1)) {
3656     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3657                               DAG.getAllOnesConstant(DL, VT));
3658     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3659   }
3660 
3661   // Look for:
3662   //   sub y, (xor x, -1)
3663   // And if the target does not like this form then turn into:
3664   //   add (add x, y), 1
3665   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3666     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3667     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3668   }
3669 
3670   // Hoist one-use addition by non-opaque constant:
3671   //   (x + C) - y  ->  (x - y) + C
3672   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
3673       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3674     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3675     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3676   }
3677   // y - (x + C)  ->  (y - x) - C
3678   if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
3679       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3680     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3681     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3682   }
3683   // (x - C) - y  ->  (x - y) - C
3684   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3685   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3686       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3687     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3688     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3689   }
3690   // (C - x) - y  ->  C - (x + y)
3691   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3692       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3693     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3694     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3695   }
3696 
3697   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3698   // rather than 'sub 0/1' (the sext should get folded).
3699   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
3700   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3701       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3702       TLI.getBooleanContents(VT) ==
3703           TargetLowering::ZeroOrNegativeOneBooleanContent) {
3704     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3705     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3706   }
3707 
3708   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
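  // Y is 0 when X >= 0 and -1 when X < 0, so (xor X, Y) - Y is X for X >= 0
  // and (~X) + 1 == -X for X < 0: the classic branchless abs.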
3709   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3710     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3711       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3712       SDValue S0 = N1.getOperand(0);
3713       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
3714         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3715           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
3716             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3717     }
3718   }
3719 
3720   // If the relocation model supports it, consider symbol offsets.
3721   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3722     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3723       // fold (sub Sym, c) -> Sym-c
3724       if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3725         return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3726                                     GA->getOffset() -
3727                                         (uint64_t)N1C->getSExtValue());
3728       // fold (sub Sym+c1, Sym+c2) -> c1-c2
3729       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3730         if (GA->getGlobal() == GB->getGlobal())
3731           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3732                                  DL, VT);
3733     }
3734 
3735   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3736   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3737     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3738     if (TN->getVT() == MVT::i1) {
3739       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3740                                  DAG.getConstant(1, DL, VT));
3741       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3742     }
3743   }
3744 
3745   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
3746   if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
3747     const APInt &IntVal = N1.getConstantOperandAPInt(0);
3748     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
3749   }
3750 
3751   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
3752   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
3753     APInt NewStep = -N1.getConstantOperandAPInt(0);
3754     return DAG.getNode(ISD::ADD, DL, VT, N0,
3755                        DAG.getStepVector(DL, VT, NewStep));
3756   }
3757 
3758   // Prefer an add for more folding potential and possibly better codegen:
3759   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
3760   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3761     SDValue ShAmt = N1.getOperand(1);
3762     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3763     if (ShAmtC &&
3764         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3765       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3766       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3767     }
3768   }
3769 
3770   // As with the previous fold, prefer add for more folding potential.
3771   // Subtracting SMIN/0 is the same as adding SMIN/0:
3772   // N0 - (X << BW-1) --> N0 + (X << BW-1)
3773   if (N1.getOpcode() == ISD::SHL) {
3774     ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
3775     if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
3776       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
3777   }
3778 
3779   // (sub (subcarry X, 0, Carry), Y) -> (subcarry X, Y, Carry)
3780   if (N0.getOpcode() == ISD::SUBCARRY && isNullConstant(N0.getOperand(1)) &&
3781       N0.getResNo() == 0 && N0.hasOneUse())
3782     return DAG.getNode(ISD::SUBCARRY, DL, N0->getVTList(),
3783                        N0.getOperand(0), N1, N0.getOperand(2));
3784 
3785   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3786     // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
3787     if (SDValue Carry = getAsCarry(TLI, N0)) {
3788       SDValue X = N1;
3789       SDValue Zero = DAG.getConstant(0, DL, VT);
3790       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3791       return DAG.getNode(ISD::ADDCARRY, DL,
3792                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3793                          Carry);
3794     }
3795   }
3796 
3797   // If there's no chance of borrowing from adjacent bits, then sub is xor:
3798   // sub C0, X --> xor X, C0
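  // e.g. (sub 15, X) with X known to fit in 4 bits only clears bits of
  // 0b1111, so it equals (xor X, 15).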
3799   if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
3800     if (!C0->isOpaque()) {
3801       const APInt &C0Val = C0->getAPIntValue();
3802       const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
3803       if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
3804         return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3805     }
3806   }
3807 
3808   // max(a,b) - min(a,b) --> abd(a,b)
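  // abd(a,b) is the absolute difference |a - b|; subtracting the min from the
  // max computes the same value in both the signed and unsigned variants.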
3809   auto MatchSubMaxMin = [&](unsigned Max, unsigned Min, unsigned Abd) {
3810     if (N0.getOpcode() != Max || N1.getOpcode() != Min)
3811       return SDValue();
3812     if ((N0.getOperand(0) != N1.getOperand(0) ||
3813          N0.getOperand(1) != N1.getOperand(1)) &&
3814         (N0.getOperand(0) != N1.getOperand(1) ||
3815          N0.getOperand(1) != N1.getOperand(0)))
3816       return SDValue();
3817     if (!TLI.isOperationLegalOrCustom(Abd, VT))
3818       return SDValue();
3819     return DAG.getNode(Abd, DL, VT, N0.getOperand(0), N0.getOperand(1));
3820   };
3821   if (SDValue R = MatchSubMaxMin(ISD::SMAX, ISD::SMIN, ISD::ABDS))
3822     return R;
3823   if (SDValue R = MatchSubMaxMin(ISD::UMAX, ISD::UMIN, ISD::ABDU))
3824     return R;
3825 
3826   return SDValue();
3827 }
3828 
3829 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3830   SDValue N0 = N->getOperand(0);
3831   SDValue N1 = N->getOperand(1);
3832   EVT VT = N0.getValueType();
3833   SDLoc DL(N);
3834 
3835   // fold (sub_sat x, undef) -> 0
3836   if (N0.isUndef() || N1.isUndef())
3837     return DAG.getConstant(0, DL, VT);
3838 
3839   // fold (sub_sat x, x) -> 0
3840   if (N0 == N1)
3841     return DAG.getConstant(0, DL, VT);
3842 
3843   // fold (sub_sat c1, c2) -> c3
3844   if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
3845     return C;
3846 
3847   // fold vector ops
3848   if (VT.isVector()) {
3849     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3850       return FoldedVOp;
3851 
3852     // fold (sub_sat x, 0) -> x, vector edition
3853     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3854       return N0;
3855   }
3856 
3857   // fold (sub_sat x, 0) -> x
3858   if (isNullConstant(N1))
3859     return N0;
3860 
3861   return SDValue();
3862 }
3863 
3864 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3865   SDValue N0 = N->getOperand(0);
3866   SDValue N1 = N->getOperand(1);
3867   EVT VT = N0.getValueType();
3868   SDLoc DL(N);
3869 
3870   // If the flag result is dead, turn this into an SUB.
3871   if (!N->hasAnyUseOfValue(1))
3872     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3873                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3874 
3875   // fold (subc x, x) -> 0 + no borrow
3876   if (N0 == N1)
3877     return CombineTo(N, DAG.getConstant(0, DL, VT),
3878                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3879 
3880   // fold (subc x, 0) -> x + no borrow
3881   if (isNullConstant(N1))
3882     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3883 
3884   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3885   if (isAllOnesConstant(N0))
3886     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3887                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3888 
3889   return SDValue();
3890 }
3891 
3892 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3893   SDValue N0 = N->getOperand(0);
3894   SDValue N1 = N->getOperand(1);
3895   EVT VT = N0.getValueType();
3896   bool IsSigned = (ISD::SSUBO == N->getOpcode());
3897 
3898   EVT CarryVT = N->getValueType(1);
3899   SDLoc DL(N);
3900 
3901   // If the flag result is dead, turn this into an SUB.
3902   if (!N->hasAnyUseOfValue(1))
3903     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3904                      DAG.getUNDEF(CarryVT));
3905 
3906   // fold (subo x, x) -> 0 + no borrow
3907   if (N0 == N1)
3908     return CombineTo(N, DAG.getConstant(0, DL, VT),
3909                      DAG.getConstant(0, DL, CarryVT));
3910 
3911   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3912 
3913   // fold (subo x, c) -> (addo x, -c)
3914   if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3915     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3916                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3917   }
3918 
3919   // fold (subo x, 0) -> x + no borrow
3920   if (isNullOrNullSplat(N1))
3921     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3922 
3923   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3924   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3925     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3926                      DAG.getConstant(0, DL, CarryVT));
3927 
3928   return SDValue();
3929 }
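// Note on the (ssubo x, c) --> (saddo x, -c) rewrite above: the
// !isMinSignedValue() guard matters because negating INT_MIN wraps back to
// itself. E.g. for i8, c = -128 gives -c = -128 again, and 0 - (-128)
// overflows while 0 + (-128) does not, so the rewrite would be wrong.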
3930 
3931 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3932   SDValue N0 = N->getOperand(0);
3933   SDValue N1 = N->getOperand(1);
3934   SDValue CarryIn = N->getOperand(2);
3935 
3936   // fold (sube x, y, false) -> (subc x, y)
3937   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3938     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3939 
3940   return SDValue();
3941 }
3942 
3943 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3944   SDValue N0 = N->getOperand(0);
3945   SDValue N1 = N->getOperand(1);
3946   SDValue CarryIn = N->getOperand(2);
3947 
3948   // fold (subcarry x, y, false) -> (usubo x, y)
3949   if (isNullConstant(CarryIn)) {
3950     if (!LegalOperations ||
3951         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3952       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3953   }
3954 
3955   return SDValue();
3956 }
3957 
3958 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
3959   SDValue N0 = N->getOperand(0);
3960   SDValue N1 = N->getOperand(1);
3961   SDValue CarryIn = N->getOperand(2);
3962 
3963   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
3964   if (isNullConstant(CarryIn)) {
3965     if (!LegalOperations ||
3966         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
3967       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
3968   }
3969 
3970   return SDValue();
3971 }
3972 
3973 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3974 // UMULFIXSAT here.
3975 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3976   SDValue N0 = N->getOperand(0);
3977   SDValue N1 = N->getOperand(1);
3978   SDValue Scale = N->getOperand(2);
3979   EVT VT = N0.getValueType();
3980 
3981   // fold (mulfix x, undef, scale) -> 0
3982   if (N0.isUndef() || N1.isUndef())
3983     return DAG.getConstant(0, SDLoc(N), VT);
3984 
3985   // Canonicalize constant to RHS (vector doesn't have to splat)
3986   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3987      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3988     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3989 
3990   // fold (mulfix x, 0, scale) -> 0
3991   if (isNullConstant(N1))
3992     return DAG.getConstant(0, SDLoc(N), VT);
3993 
3994   return SDValue();
3995 }
3996 
3997 SDValue DAGCombiner::visitMUL(SDNode *N) {
3998   SDValue N0 = N->getOperand(0);
3999   SDValue N1 = N->getOperand(1);
4000   EVT VT = N0.getValueType();
4001   SDLoc DL(N);
4002 
4003   // fold (mul x, undef) -> 0
4004   if (N0.isUndef() || N1.isUndef())
4005     return DAG.getConstant(0, DL, VT);
4006 
4007   // fold (mul c1, c2) -> c1*c2
4008   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
4009     return C;
4010 
4011   // canonicalize constant to RHS (vector doesn't have to splat)
4012   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4013       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4014     return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
4015 
4016   bool N1IsConst = false;
4017   bool N1IsOpaqueConst = false;
4018   APInt ConstValue1;
4019 
4020   // fold vector ops
4021   if (VT.isVector()) {
4022     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4023       return FoldedVOp;
4024 
4025     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4026     assert((!N1IsConst ||
4027             ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
4028            "Splat APInt should be element width");
4029   } else {
4030     N1IsConst = isa<ConstantSDNode>(N1);
4031     if (N1IsConst) {
4032       ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
4033       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
4034     }
4035   }
4036 
4037   // fold (mul x, 0) -> 0
4038   if (N1IsConst && ConstValue1.isZero())
4039     return N1;
4040 
4041   // fold (mul x, 1) -> x
4042   if (N1IsConst && ConstValue1.isOne())
4043     return N0;
4044 
4045   if (SDValue NewSel = foldBinOpIntoSelect(N))
4046     return NewSel;
4047 
4048   // fold (mul x, -1) -> 0-x
4049   if (N1IsConst && ConstValue1.isAllOnes())
4050     return DAG.getNegative(N0, DL, VT);
4051 
4052   // fold (mul x, (1 << c)) -> x << c
4053   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4054       DAG.isKnownToBeAPowerOfTwo(N1) &&
4055       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4056     SDValue LogBase2 = BuildLogBase2(N1, DL);
4057     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4058     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4059     return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
4060   }
4061 
4062   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4063   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4064     unsigned Log2Val = (-ConstValue1).logBase2();
4065     // FIXME: If the input is something that is easily negated (e.g. a
4066     // single-use add), we should put the negate there.
4067     return DAG.getNode(ISD::SUB, DL, VT,
4068                        DAG.getConstant(0, DL, VT),
4069                        DAG.getNode(ISD::SHL, DL, VT, N0,
4070                             DAG.getConstant(Log2Val, DL,
4071                                       getShiftAmountTy(N0.getValueType()))));
4072   }
4073 
4074   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4075   // hi result is in use in case we hit this mid-legalization.
4076   for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4077     if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4078       SDVTList LoHiVT = DAG.getVTList(VT, VT);
4079       // TODO: Can we match commutable operands with getNodeIfExists?
4080       if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4081         if (LoHi->hasAnyUseOfValue(1))
4082           return SDValue(LoHi, 0);
4083       if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4084         if (LoHi->hasAnyUseOfValue(1))
4085           return SDValue(LoHi, 0);
4086     }
4087   }
4088 
4089   // Try to transform:
4090   // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4091   // mul x, (2^N + 1) --> add (shl x, N), x
4092   // mul x, (2^N - 1) --> sub (shl x, N), x
4093   // Examples: x * 33 --> (x << 5) + x
4094   //           x * 15 --> (x << 4) - x
4095   //           x * -33 --> -((x << 5) + x)
4096   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4097   // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4098   // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4099   // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4100   // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4101   //           x * 0xf800 --> (x << 16) - (x << 11)
4102   //           x * -0x8800 --> -((x << 15) + (x << 11))
4103   //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4104   if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4105     // TODO: We could handle more general decomposition of any constant by
4106     //       having the target set a limit on number of ops and making a
4107     //       callback to determine that sequence (similar to sqrt expansion).
4108     unsigned MathOp = ISD::DELETED_NODE;
4109     APInt MulC = ConstValue1.abs();
4110     // The constant `2` should be treated as (2^0 + 1).
4111     unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
4112     MulC.lshrInPlace(TZeros);
4113     if ((MulC - 1).isPowerOf2())
4114       MathOp = ISD::ADD;
4115     else if ((MulC + 1).isPowerOf2())
4116       MathOp = ISD::SUB;
4117 
4118     if (MathOp != ISD::DELETED_NODE) {
4119       unsigned ShAmt =
4120           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4121       ShAmt += TZeros;
4122       assert(ShAmt < VT.getScalarSizeInBits() &&
4123              "multiply-by-constant generated out of bounds shift");
4124       SDValue Shl =
4125           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
4126       SDValue R =
4127           TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
4128                                DAG.getNode(ISD::SHL, DL, VT, N0,
4129                                            DAG.getConstant(TZeros, DL, VT)))
4130                  : DAG.getNode(MathOp, DL, VT, Shl, N0);
4131       if (ConstValue1.isNegative())
4132         R = DAG.getNegative(R, DL, VT);
4133       return R;
4134     }
4135   }
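  // Worked instance of the decomposition above (when the target opts in):
  // ConstValue1 = 40 = 0b101000. TZeros = 3, MulC becomes 5, and
  // (5 - 1) = 4 is a power of two, so MathOp = ADD with ShAmt = 2 + 3 = 5:
  //   mul x, 40 --> add (shl x, 5), (shl x, 3)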
4136 
4137   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4138   if (N0.getOpcode() == ISD::SHL) {
4139     SDValue N01 = N0.getOperand(1);
4140     if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4141       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
4142   }
4143 
4144   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4145   // use.
4146   {
4147     SDValue Sh, Y;
4148 
4149     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
4150     if (N0.getOpcode() == ISD::SHL &&
4151         isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
4152       Sh = N0; Y = N1;
4153     } else if (N1.getOpcode() == ISD::SHL &&
4154                isConstantOrConstantVector(N1.getOperand(1)) &&
4155                N1->hasOneUse()) {
4156       Sh = N1; Y = N0;
4157     }
4158 
4159     if (Sh.getNode()) {
4160       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4161       return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4162     }
4163   }
4164 
4165   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4166   if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4167       N0.getOpcode() == ISD::ADD &&
4168       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4169       isMulAddWithConstProfitable(N, N0, N1))
4170     return DAG.getNode(
4171         ISD::ADD, DL, VT,
4172         DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4173         DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4174 
4175   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4176   ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4177   if (N0.getOpcode() == ISD::VSCALE && NC1) {
4178     const APInt &C0 = N0.getConstantOperandAPInt(0);
4179     const APInt &C1 = NC1->getAPIntValue();
4180     return DAG.getVScale(DL, VT, C0 * C1);
4181   }
4182 
4183   // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4184   APInt MulVal;
4185   if (N0.getOpcode() == ISD::STEP_VECTOR &&
4186       ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4187     const APInt &C0 = N0.getConstantOperandAPInt(0);
4188     APInt NewStep = C0 * MulVal;
4189     return DAG.getStepVector(DL, VT, NewStep);
4190   }
4191 
4192   // Fold (mul x, 0/undef) -> 0 and
4193   //      (mul x, 1) -> x
4194   // into and(x, mask).
4195   // We can replace vectors with '0' and '1' factors with a clearing mask.
4196   if (VT.isFixedLengthVector()) {
4197     unsigned NumElts = VT.getVectorNumElements();
4198     SmallBitVector ClearMask;
4199     ClearMask.reserve(NumElts);
4200     auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4201       if (!V || V->isZero()) {
4202         ClearMask.push_back(true);
4203         return true;
4204       }
4205       ClearMask.push_back(false);
4206       return V->isOne();
4207     };
4208     if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
4209         ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
4210       assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4211       EVT LegalSVT = N1.getOperand(0).getValueType();
4212       SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
4213       SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
4214       SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4215       for (unsigned I = 0; I != NumElts; ++I)
4216         if (ClearMask[I])
4217           Mask[I] = Zero;
4218       return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
4219     }
4220   }
4221 
4222   // reassociate mul
4223   if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4224     return RMUL;
4225 
4226   // Simplify the operands using demanded-bits information.
4227   if (SimplifyDemandedBits(SDValue(N, 0)))
4228     return SDValue(N, 0);
4229 
4230   return SDValue();
4231 }
4232 
4233 /// Return true if divmod libcall is available.
4234 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4235                                      const TargetLowering &TLI) {
4236   RTLIB::Libcall LC;
4237   EVT NodeType = Node->getValueType(0);
4238   if (!NodeType.isSimple())
4239     return false;
4240   switch (NodeType.getSimpleVT().SimpleTy) {
4241   default: return false; // No libcall for vector types.
4242   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4243   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4244   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4245   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4246   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4247   }
4248 
4249   return TLI.getLibcallName(LC) != nullptr;
4250 }
4251 
4252 /// Issue divrem if both quotient and remainder are needed.
4253 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4254   if (Node->use_empty())
4255     return SDValue(); // This is a dead node, leave it alone.
4256 
4257   unsigned Opcode = Node->getOpcode();
4258   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4259   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4260 
4261   // DivMod lib calls can still work on non-legal types if using lib-calls.
4262   EVT VT = Node->getValueType(0);
4263   if (VT.isVector() || !VT.isInteger())
4264     return SDValue();
4265 
4266   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4267     return SDValue();
4268 
4269   // If DIVREM is going to get expanded into a libcall,
4270   // but there is no libcall available, then don't combine.
4271   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4272       !isDivRemLibcallAvailable(Node, isSigned, TLI))
4273     return SDValue();
4274 
4275   // If div is legal, it's better to do the normal expansion
4276   unsigned OtherOpcode = 0;
4277   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4278     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4279     if (TLI.isOperationLegalOrCustom(Opcode, VT))
4280       return SDValue();
4281   } else {
4282     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4283     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4284       return SDValue();
4285   }
4286 
4287   SDValue Op0 = Node->getOperand(0);
4288   SDValue Op1 = Node->getOperand(1);
4289   SDValue combined;
4290   for (SDNode *User : Op0->uses()) {
4291     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4292         User->use_empty())
4293       continue;
4294     // Convert the other matching node(s), too;
4295     // otherwise, the DIVREM may get target-legalized into something
4296     // target-specific that we won't be able to recognize.
4297     unsigned UserOpc = User->getOpcode();
4298     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4299         User->getOperand(0) == Op0 &&
4300         User->getOperand(1) == Op1) {
4301       if (!combined) {
4302         if (UserOpc == OtherOpcode) {
4303           SDVTList VTs = DAG.getVTList(VT, VT);
4304           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4305         } else if (UserOpc == DivRemOpc) {
4306           combined = SDValue(User, 0);
4307         } else {
4308           assert(UserOpc == Opcode);
4309           continue;
4310         }
4311       }
4312       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4313         CombineTo(User, combined);
4314       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4315         CombineTo(User, combined.getValue(1));
4316     }
4317   }
4318   return combined;
4319 }
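// Sketch of the effect above: if the DAG contains both (sdiv x, y) and
// (srem x, y), each user is redirected to one (sdivrem x, y) node --
// quotient users to result 0, remainder users to result 1.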
4320 
4321 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4322   SDValue N0 = N->getOperand(0);
4323   SDValue N1 = N->getOperand(1);
4324   EVT VT = N->getValueType(0);
4325   SDLoc DL(N);
4326 
4327   unsigned Opc = N->getOpcode();
4328   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4329   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4330 
4331   // X / undef -> undef
4332   // X % undef -> undef
4333   // X / 0 -> undef
4334   // X % 0 -> undef
4335   // NOTE: This includes vectors where any divisor element is zero/undef.
4336   if (DAG.isUndef(Opc, {N0, N1}))
4337     return DAG.getUNDEF(VT);
4338 
4339   // undef / X -> 0
4340   // undef % X -> 0
4341   if (N0.isUndef())
4342     return DAG.getConstant(0, DL, VT);
4343 
4344   // 0 / X -> 0
4345   // 0 % X -> 0
4346   ConstantSDNode *N0C = isConstOrConstSplat(N0);
4347   if (N0C && N0C->isZero())
4348     return N0;
4349 
4350   // X / X -> 1
4351   // X % X -> 0
4352   if (N0 == N1)
4353     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4354 
4355   // X / 1 -> X
4356   // X % 1 -> 0
4357   // If this is a boolean op (single-bit element type), we can't have
4358   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4359   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4360   // it's a 1.
4361   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4362     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4363 
4364   return SDValue();
4365 }
4366 
4367 SDValue DAGCombiner::visitSDIV(SDNode *N) {
4368   SDValue N0 = N->getOperand(0);
4369   SDValue N1 = N->getOperand(1);
4370   EVT VT = N->getValueType(0);
4371   EVT CCVT = getSetCCResultType(VT);
4372   SDLoc DL(N);
4373 
4374   // fold (sdiv c1, c2) -> c1/c2
4375   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4376     return C;
4377 
4378   // fold vector ops
4379   if (VT.isVector())
4380     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4381       return FoldedVOp;
4382 
4383   // fold (sdiv X, -1) -> 0-X
4384   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4385   if (N1C && N1C->isAllOnes())
4386     return DAG.getNegative(N0, DL, VT);
4387 
4388   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4389   if (N1C && N1C->getAPIntValue().isMinSignedValue())
4390     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4391                          DAG.getConstant(1, DL, VT),
4392                          DAG.getConstant(0, DL, VT));
4393 
4394   if (SDValue V = simplifyDivRem(N, DAG))
4395     return V;
4396 
4397   if (SDValue NewSel = foldBinOpIntoSelect(N))
4398     return NewSel;
4399 
4400   // If we know the sign bits of both operands are zero, strength reduce to a
4401   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
4402   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4403     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4404 
4405   if (SDValue V = visitSDIVLike(N0, N1, N)) {
4406     // If the corresponding remainder node exists, update its users with
4407     // (Dividend - (Quotient * Divisor)).
4408     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4409                                               { N0, N1 })) {
4410       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4411       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4412       AddToWorklist(Mul.getNode());
4413       AddToWorklist(Sub.getNode());
4414       CombineTo(RemNode, Sub);
4415     }
4416     return V;
4417   }
4418 
4419   // sdiv, srem -> sdivrem
4420   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4421   // true.  Otherwise, we break the simplification logic in visitREM().
4422   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4423   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4424     if (SDValue DivRem = useDivRem(N))
4425         return DivRem;
4426       return DivRem;
4427   return SDValue();
4428 }
4429 
4430 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4431   // Helper for determining whether a value is a power-2 constant scalar or a
4432   // vector of such elements.
4433   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4434     if (C->isZero() || C->isOpaque())
4435       return false;
4436     if (C->getAPIntValue().isPowerOf2())
4437       return true;
4438     if (C->getAPIntValue().isNegatedPowerOf2())
4439       return true;
4440     return false;
4441   };
4442 
4443   return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4444 }
4445 
4446 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4447   SDLoc DL(N);
4448   EVT VT = N->getValueType(0);
4449   EVT CCVT = getSetCCResultType(VT);
4450   unsigned BitWidth = VT.getScalarSizeInBits();
4451 
4452   // fold (sdiv X, pow2) -> simple ops after legalize
4453   // FIXME: We check for the exact bit here because the generic lowering gives
4454   // better results in that case. The target-specific lowering should learn how
4455   // to handle exact sdivs efficiently.
4456   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
4457     // Target-specific implementation of sdiv x, pow2.
4458     if (SDValue Res = BuildSDIVPow2(N))
4459       return Res;
4460 
4461     // Create constants that are functions of the shift amount value.
4462     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4463     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4464     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4465     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4466     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4467     if (!isConstantOrConstantVector(Inexact))
4468       return SDValue();
4469 
4470     // Splat the sign bit into the register
4471     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4472                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4473     AddToWorklist(Sign.getNode());
4474 
4475     // Add (N0 < 0) ? abs2 - 1 : 0;
4476     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4477     AddToWorklist(Srl.getNode());
4478     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4479     AddToWorklist(Add.getNode());
4480     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4481     AddToWorklist(Sra.getNode());
4482 
4483     // Special case: (sdiv X, 1) -> X
4484     // Special Case: (sdiv X, -1) -> 0-X
4485     SDValue One = DAG.getConstant(1, DL, VT);
4486     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4487     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4488     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4489     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4490     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4491 
4492     // If dividing by a positive value, we're done. Otherwise, the result must
4493     // be negated.
4494     SDValue Zero = DAG.getConstant(0, DL, VT);
4495     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4496 
4497     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4498     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4499     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4500     return Res;
4501   }
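  // Worked example of the generic pow2 sequence above (i32, N1 = 8):
  //   C1      = cttz(8) = 3,  Inexact = 32 - 3 = 29
  //   Sign    = x >>s 31              ; 0 or -1
  //   Srl     = Sign >>u 29           ; 0 or 7 (= |8| - 1)
  //   Sra     = (x + Srl) >>s 3       ; sdiv x, 8, rounded toward zero
  // The selects afterwards patch up the divisor == 1, -1, and negative
  // cases.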
4502 
4503   // If integer divide is expensive and we satisfy the requirements, emit an
4504   // alternate sequence.  Targets may check function attributes for size/speed
4505   // trade-offs.
4506   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4507   if (isConstantOrConstantVector(N1) &&
4508       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4509     if (SDValue Op = BuildSDIV(N))
4510       return Op;
4511 
4512   return SDValue();
4513 }
4514 
4515 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4516   SDValue N0 = N->getOperand(0);
4517   SDValue N1 = N->getOperand(1);
4518   EVT VT = N->getValueType(0);
4519   EVT CCVT = getSetCCResultType(VT);
4520   SDLoc DL(N);
4521 
4522   // fold (udiv c1, c2) -> c1/c2
4523   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4524     return C;
4525 
4526   // fold vector ops
4527   if (VT.isVector())
4528     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4529       return FoldedVOp;
4530 
4531   // fold (udiv X, -1) -> select(X == -1, 1, 0)
4532   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4533   if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4534     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4535                          DAG.getConstant(1, DL, VT),
4536                          DAG.getConstant(0, DL, VT));
4537   }
4538 
4539   if (SDValue V = simplifyDivRem(N, DAG))
4540     return V;
4541 
4542   if (SDValue NewSel = foldBinOpIntoSelect(N))
4543     return NewSel;
4544 
4545   if (SDValue V = visitUDIVLike(N0, N1, N)) {
4546     // If the corresponding remainder node exists, update its users with
4547     // (Dividend - (Quotient * Divisor)).
4548     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4549                                               { N0, N1 })) {
4550       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4551       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4552       AddToWorklist(Mul.getNode());
4553       AddToWorklist(Sub.getNode());
4554       CombineTo(RemNode, Sub);
4555     }
4556     return V;
4557   }
4558 
4559   // udiv, urem -> udivrem
4560   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4561   // true.  Otherwise, we break the simplification logic in visitREM().
4562   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4563   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4564     if (SDValue DivRem = useDivRem(N))
4565         return DivRem;
4566       return DivRem;
4567   return SDValue();
4568 }
4569 
4570 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4571   SDLoc DL(N);
4572   EVT VT = N->getValueType(0);
4573 
4574   // fold (udiv x, (1 << c)) -> x >>u c
4575   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4576       DAG.isKnownToBeAPowerOfTwo(N1)) {
4577     SDValue LogBase2 = BuildLogBase2(N1, DL);
4578     AddToWorklist(LogBase2.getNode());
4579 
4580     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4581     SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4582     AddToWorklist(Trunc.getNode());
4583     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4584   }
4585 
4586   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4587   if (N1.getOpcode() == ISD::SHL) {
4588     SDValue N10 = N1.getOperand(0);
4589     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
4590         DAG.isKnownToBeAPowerOfTwo(N10)) {
4591       SDValue LogBase2 = BuildLogBase2(N10, DL);
4592       AddToWorklist(LogBase2.getNode());
4593 
4594       EVT ADDVT = N1.getOperand(1).getValueType();
4595       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4596       AddToWorklist(Trunc.getNode());
4597       SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4598       AddToWorklist(Add.getNode());
4599       return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4600     }
4601   }
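  // Examples of the two unsigned-divide folds above (i32):
  //   udiv x, 16         --> srl x, 4
  //   udiv x, (shl 4, y) --> srl x, (add y, 2)    ; since 4 = 1 << 2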
4602 
4603   // fold (udiv x, c) -> alternate
4604   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4605   if (isConstantOrConstantVector(N1) &&
4606       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4607     if (SDValue Op = BuildUDIV(N))
4608       return Op;
4609 
4610   return SDValue();
4611 }
4612 
4613 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4614   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4615       !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4616     // Target-specific implementation of srem x, pow2.
4617     if (SDValue Res = BuildSREMPow2(N))
4618       return Res;
4619   }
4620   return SDValue();
4621 }
4622 
4623 // handles ISD::SREM and ISD::UREM
4624 SDValue DAGCombiner::visitREM(SDNode *N) {
4625   unsigned Opcode = N->getOpcode();
4626   SDValue N0 = N->getOperand(0);
4627   SDValue N1 = N->getOperand(1);
4628   EVT VT = N->getValueType(0);
4629   EVT CCVT = getSetCCResultType(VT);
4630 
4631   bool isSigned = (Opcode == ISD::SREM);
4632   SDLoc DL(N);
4633 
4634   // fold (rem c1, c2) -> c1%c2
4635   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4636     return C;
4637 
4638   // fold (urem X, -1) -> select(FX == -1, 0, FX)
4639   // Freeze the numerator to avoid a miscompile with an undefined value.
4640   if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
4641       CCVT.isVector() == VT.isVector()) {
4642     SDValue F0 = DAG.getFreeze(N0);
4643     SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
4644     return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
4645   }
4646 
4647   if (SDValue V = simplifyDivRem(N, DAG))
4648     return V;
4649 
4650   if (SDValue NewSel = foldBinOpIntoSelect(N))
4651     return NewSel;
4652 
4653   if (isSigned) {
4654     // If we know the sign bits of both operands are zero, strength reduce to a
4655     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4656     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4657       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4658   } else {
4659     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4660       // fold (urem x, pow2) -> (and x, pow2-1)
4661       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4662       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4663       AddToWorklist(Add.getNode());
4664       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4665     }
4666     // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4667     // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
4668     // TODO: We should sink the following into isKnownToBePowerOfTwo
4669     // using a OrZero parameter analogous to our handling in ValueTracking.
4670     if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
4671         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4672       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4673       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4674       AddToWorklist(Add.getNode());
4675       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4676     }
4677   }
4678 
4679   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4680 
4681   // If X/C can be simplified by the division-by-constant logic, lower
4682   // X%C to the equivalent of X-X/C*C.
4683   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4684   // speculative DIV must not cause a DIVREM conversion.  We guard against this
4685   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
4686   // combine will not return a DIVREM.  Regardless, checking cheapness here
4687   // makes sense since the simplification results in fatter code.
4688   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4689     if (isSigned) {
4690       // check if we can build faster implementation for srem
4691       if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
4692         return OptimizedRem;
4693     }
4694 
4695     SDValue OptimizedDiv =
4696         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4697     if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
4698       // If the equivalent Div node also exists, update its users.
4699       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4700       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4701                                                 { N0, N1 }))
4702         CombineTo(DivNode, OptimizedDiv);
4703       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4704       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4705       AddToWorklist(OptimizedDiv.getNode());
4706       AddToWorklist(Mul.getNode());
4707       return Sub;
4708     }
4709   }
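  // Sketch of the remainder expansion above: once X/C simplifies to
  // OptimizedDiv, X%C is recovered as X - (X/C)*C. E.g. urem x, 10 on a
  // target with a cheap magic-number udiv becomes
  //   sub x, (mul (udiv x, 10), 10)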
4710 
4711   // sdiv/udiv, srem/urem -> sdivrem/udivrem
4712   if (SDValue DivRem = useDivRem(N))
4713     return DivRem.getValue(1);
4714 
4715   return SDValue();
4716 }
4717 
4718 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4719   SDValue N0 = N->getOperand(0);
4720   SDValue N1 = N->getOperand(1);
4721   EVT VT = N->getValueType(0);
4722   SDLoc DL(N);
4723 
4724   // fold (mulhs c1, c2)
4725   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
4726     return C;
4727 
4728   // canonicalize constant to RHS.
4729   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4730       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4731     return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
4732 
4733   if (VT.isVector()) {
4734     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4735       return FoldedVOp;
4736 
4737     // fold (mulhs x, 0) -> 0
4738     // do not return N1, because undef node may exist.
4739     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4740       return DAG.getConstant(0, DL, VT);
4741   }
4742 
4743   // fold (mulhs x, 0) -> 0
4744   if (isNullConstant(N1))
4745     return N1;
4746 
4747   // fold (mulhs x, 1) -> (sra x, size(x)-1)
4748   if (isOneConstant(N1))
4749     return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4750                        DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4751                                        getShiftAmountTy(N0.getValueType())));
4752 
4753   // fold (mulhs x, undef) -> 0
4754   if (N0.isUndef() || N1.isUndef())
4755     return DAG.getConstant(0, DL, VT);
4756 
4757   // If the type twice as wide is legal, transform the mulhs to a wider multiply
4758   // plus a shift.
4759   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
4760       !VT.isVector()) {
4761     MVT Simple = VT.getSimpleVT();
4762     unsigned SimpleSize = Simple.getSizeInBits();
4763     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4764     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4765       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4766       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4767       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4768       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4769             DAG.getConstant(SimpleSize, DL,
4770                             getShiftAmountTy(N1.getValueType())));
4771       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4772     }
4773   }
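  // Worked example of the widening above, assuming MULHS is not legal for
  // i16 but MUL is legal for i32:
  //   mulhs a, b --> trunc (srl (mul (sext a), (sext b)), 16)
  // i.e. the high half of the 32-bit signed product.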
4774 
4775   return SDValue();
4776 }
4777 
4778 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4779   SDValue N0 = N->getOperand(0);
4780   SDValue N1 = N->getOperand(1);
4781   EVT VT = N->getValueType(0);
4782   SDLoc DL(N);
4783 
4784   // fold (mulhu c1, c2)
4785   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
4786     return C;
4787 
4788   // canonicalize constant to RHS.
4789   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4790       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4791     return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
4792 
4793   if (VT.isVector()) {
4794     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4795       return FoldedVOp;
4796 
4797     // fold (mulhu x, 0) -> 0
4798     // do not return N1, because undef node may exist.
4799     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4800       return DAG.getConstant(0, DL, VT);
4801   }
4802 
4803   // fold (mulhu x, 0) -> 0
4804   if (isNullConstant(N1))
4805     return N1;
4806 
4807   // fold (mulhu x, 1) -> 0
4808   if (isOneConstant(N1))
4809     return DAG.getConstant(0, DL, N0.getValueType());
4810 
4811   // fold (mulhu x, undef) -> 0
4812   if (N0.isUndef() || N1.isUndef())
4813     return DAG.getConstant(0, DL, VT);
4814 
4815   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
4816   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4817       DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4818     unsigned NumEltBits = VT.getScalarSizeInBits();
4819     SDValue LogBase2 = BuildLogBase2(N1, DL);
4820     SDValue SRLAmt = DAG.getNode(
4821         ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4822     EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4823     SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4824     return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4825   }
4826 
4827   // If the type twice as wide is legal, transform the mulhu to a wider multiply
4828   // plus a shift.
4829   if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
4830       !VT.isVector()) {
4831     MVT Simple = VT.getSimpleVT();
4832     unsigned SimpleSize = Simple.getSizeInBits();
4833     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4834     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4835       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4836       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4837       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4838       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4839             DAG.getConstant(SimpleSize, DL,
4840                             getShiftAmountTy(N1.getValueType())));
4841       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4842     }
4843   }
4844 
4845   // Simplify the operands using demanded-bits information.
4846   // We don't have demanded bits support for MULHU so this just enables constant
4847   // folding based on known bits.
4848   if (SimplifyDemandedBits(SDValue(N, 0)))
4849     return SDValue(N, 0);
4850 
4851   return SDValue();
4852 }
4853 
4854 SDValue DAGCombiner::visitAVG(SDNode *N) {
4855   unsigned Opcode = N->getOpcode();
4856   SDValue N0 = N->getOperand(0);
4857   SDValue N1 = N->getOperand(1);
4858   EVT VT = N->getValueType(0);
4859   SDLoc DL(N);
4860 
4861   // fold (avg c1, c2)
4862   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4863     return C;
4864 
4865   // canonicalize constant to RHS.
4866   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4867       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4868     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
4869 
4870   if (VT.isVector()) {
4871     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4872       return FoldedVOp;
4873 
4874     // fold (avgfloor x, 0) -> x >> 1
4875     if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
4876       if (Opcode == ISD::AVGFLOORS)
4877         return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
4878       if (Opcode == ISD::AVGFLOORU)
4879         return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
4880     }
4881   }
4882 
4883   // fold (avg x, undef) -> x
4884   if (N0.isUndef())
4885     return N1;
4886   if (N1.isUndef())
4887     return N0;
4888 
4889   // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
4890 
4891   return SDValue();
4892 }
4893 
4894 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4895 /// give the opcodes for the two computations that are being performed. Return
4896 /// the combined value if a simplification was made.
4897 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4898                                                 unsigned HiOp) {
4899   // If the high half is not needed, just compute the low half.
4900   bool HiExists = N->hasAnyUseOfValue(1);
4901   if (!HiExists && (!LegalOperations ||
4902                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4903     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4904     return CombineTo(N, Res, Res);
4905   }
4906 
4907   // If the low half is not needed, just compute the high half.
4908   bool LoExists = N->hasAnyUseOfValue(0);
4909   if (!LoExists && (!LegalOperations ||
4910                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4911     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4912     return CombineTo(N, Res, Res);
4913   }
4914 
4915   // If both halves are used, return as it is.
4916   if (LoExists && HiExists)
4917     return SDValue();
4918 
4919   // If the two computed results can be simplified separately, separate them.
4920   if (LoExists) {
4921     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4922     AddToWorklist(Lo.getNode());
4923     SDValue LoOpt = combine(Lo.getNode());
4924     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4925         (!LegalOperations ||
4926          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4927       return CombineTo(N, LoOpt, LoOpt);
4928   }
4929 
4930   if (HiExists) {
4931     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4932     AddToWorklist(Hi.getNode());
4933     SDValue HiOpt = combine(Hi.getNode());
4934     if (HiOpt.getNode() && HiOpt != Hi &&
4935         (!LegalOperations ||
4936          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4937       return CombineTo(N, HiOpt, HiOpt);
4938   }
4939 
4940   return SDValue();
4941 }
4942 
4943 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4944   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4945     return Res;
4946 
4947   SDValue N0 = N->getOperand(0);
4948   SDValue N1 = N->getOperand(1);
4949   EVT VT = N->getValueType(0);
4950   SDLoc DL(N);
4951 
4952   // canonicalize constant to RHS (vector doesn't have to splat)
4953   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4954       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4955     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
4956 
4957   // If the type twice as wide is legal, transform the smul_lohi to a wider
4958   // multiply plus a shift.
4959   if (VT.isSimple() && !VT.isVector()) {
4960     MVT Simple = VT.getSimpleVT();
4961     unsigned SimpleSize = Simple.getSizeInBits();
4962     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4963     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4964       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4965       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4966       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4967       // Compute the high part of the result (value 1).
4968       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4969             DAG.getConstant(SimpleSize, DL,
4970                             getShiftAmountTy(Lo.getValueType())));
4971       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4972       // Compute the low part of the result (value 0).
4973       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4974       return CombineTo(N, Lo, Hi);
4975     }
4976   }
4977 
4978   return SDValue();
4979 }
4980 
4981 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4982   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4983     return Res;
4984 
4985   SDValue N0 = N->getOperand(0);
4986   SDValue N1 = N->getOperand(1);
4987   EVT VT = N->getValueType(0);
4988   SDLoc DL(N);
4989 
4990   // canonicalize constant to RHS (vector doesn't have to splat)
4991   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4992       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4993     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
4994 
4995   // (umul_lohi N0, 0) -> (0, 0)
4996   if (isNullConstant(N1)) {
4997     SDValue Zero = DAG.getConstant(0, DL, VT);
4998     return CombineTo(N, Zero, Zero);
4999   }
5000 
5001   // (umul_lohi N0, 1) -> (N0, 0)
5002   if (isOneConstant(N1)) {
5003     SDValue Zero = DAG.getConstant(0, DL, VT);
5004     return CombineTo(N, N0, Zero);
5005   }
5006 
5007   // If the type twice as wide is legal, transform the umul_lohi to a wider
5008   // multiply plus a shift.
5009   if (VT.isSimple() && !VT.isVector()) {
5010     MVT Simple = VT.getSimpleVT();
5011     unsigned SimpleSize = Simple.getSizeInBits();
5012     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5013     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5014       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5015       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5016       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5017       // Compute the high part of the result (value 1).
5018       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5019             DAG.getConstant(SimpleSize, DL,
5020                             getShiftAmountTy(Lo.getValueType())));
5021       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5022       // Compute the low part of the result (value 0).
5023       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5024       return CombineTo(N, Lo, Hi);
5025     }
5026   }
5027 
5028   return SDValue();
5029 }
5030 
5031 SDValue DAGCombiner::visitMULO(SDNode *N) {
5032   SDValue N0 = N->getOperand(0);
5033   SDValue N1 = N->getOperand(1);
5034   EVT VT = N0.getValueType();
5035   bool IsSigned = (ISD::SMULO == N->getOpcode());
5036 
5037   EVT CarryVT = N->getValueType(1);
5038   SDLoc DL(N);
5039 
5040   ConstantSDNode *N0C = isConstOrConstSplat(N0);
5041   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5042 
5043   // fold operation with constant operands.
5044   // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5045   // multiple results.
5046   if (N0C && N1C) {
5047     bool Overflow;
5048     APInt Result =
5049         IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
5050                  : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
5051     return CombineTo(N, DAG.getConstant(Result, DL, VT),
5052                      DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
5053   }
5054 
5055   // canonicalize constant to RHS.
5056   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5057       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5058     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
5059 
5060   // fold (mulo x, 0) -> 0 + no carry out
5061   if (isNullOrNullSplat(N1))
5062     return CombineTo(N, DAG.getConstant(0, DL, VT),
5063                      DAG.getConstant(0, DL, CarryVT));
5064 
5065   // (mulo x, 2) -> (addo x, x)
5066   // FIXME: This needs a freeze.
5067   if (N1C && N1C->getAPIntValue() == 2 &&
5068       (!IsSigned || VT.getScalarSizeInBits() > 2))
5069     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5070                        N->getVTList(), N0, N0);
5071 
5072   if (IsSigned) {
5073     // A 1 bit SMULO overflows if both inputs are 1.
5074     if (VT.getScalarSizeInBits() == 1) {
5075       SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
5076       return CombineTo(N, And,
5077                        DAG.getSetCC(DL, CarryVT, And,
5078                                     DAG.getConstant(0, DL, VT), ISD::SETNE));
5079     }
5080 
5081     // Multiplying n * m significant bits yields a result of n + m significant
5082     // bits. If the total number of significant bits does not exceed the
5083     // result bit width (minus 1), there is no overflow.
5084     unsigned SignBits = DAG.ComputeNumSignBits(N0);
5085     if (SignBits > 1)
5086       SignBits += DAG.ComputeNumSignBits(N1);
5087     if (SignBits > VT.getScalarSizeInBits() + 1)
5088       return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5089                        DAG.getConstant(0, DL, CarryVT));
5090   } else {
5091     KnownBits N1Known = DAG.computeKnownBits(N1);
5092     KnownBits N0Known = DAG.computeKnownBits(N0);
5093     bool Overflow;
5094     (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
5095     if (!Overflow)
5096       return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5097                        DAG.getConstant(0, DL, CarryVT));
5098   }
5099 
5100   return SDValue();
5101 }
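// Numeric check of the sign-bit reasoning in visitMULO above (i32 smulo):
// if N0 has 20 sign bits and N1 has 15, the operands carry at most 13 and
// 18 significant bits, so the product needs at most 31 <= 32 - 1 bits and
// cannot overflow; the test 20 + 15 = 35 > 32 + 1 encodes exactly this.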
5102 
5103 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5104 // swapped around) make a signed saturate pattern, clamping to between a signed
5105 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
5106 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5107 // work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5108 // same as SimplifySelectCC. N0<N1 ? N2 : N3.
5109 static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5110                                   SDValue N3, ISD::CondCode CC, unsigned &BW,
5111                                   bool &Unsigned) {
5112   auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5113                             ISD::CondCode CC) {
5114     // The compare and select operand should be the same or the select operands
5115     // should be truncated versions of the comparison.
5116     if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
5117       return 0;
5118     // The constants need to be the same or a truncated version of each other.
5119     ConstantSDNode *N1C = isConstOrConstSplat(N1);
5120     ConstantSDNode *N3C = isConstOrConstSplat(N3);
5121     if (!N1C || !N3C)
5122       return 0;
5123     const APInt &C1 = N1C->getAPIntValue();
5124     const APInt &C2 = N3C->getAPIntValue();
5125     if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
5126       return 0;
5127     return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5128   };
5129 
5130   // Check the initial value is a SMIN/SMAX equivalent.
5131   unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5132   if (!Opcode0)
5133     return SDValue();
5134 
5135   SDValue N00, N01, N02, N03;
5136   ISD::CondCode N0CC;
5137   switch (N0.getOpcode()) {
5138   case ISD::SMIN:
5139   case ISD::SMAX:
5140     N00 = N02 = N0.getOperand(0);
5141     N01 = N03 = N0.getOperand(1);
5142     N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5143     break;
5144   case ISD::SELECT_CC:
5145     N00 = N0.getOperand(0);
5146     N01 = N0.getOperand(1);
5147     N02 = N0.getOperand(2);
5148     N03 = N0.getOperand(3);
5149     N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
5150     break;
5151   case ISD::SELECT:
5152   case ISD::VSELECT:
5153     if (N0.getOperand(0).getOpcode() != ISD::SETCC)
5154       return SDValue();
5155     N00 = N0.getOperand(0).getOperand(0);
5156     N01 = N0.getOperand(0).getOperand(1);
5157     N02 = N0.getOperand(1);
5158     N03 = N0.getOperand(2);
5159     N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
5160     break;
5161   default:
5162     return SDValue();
5163   }
5164 
5165   unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5166   if (!Opcode1 || Opcode0 == Opcode1)
5167     return SDValue();
5168 
5169   ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
5170   ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
5171   if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
5172     return SDValue();
5173 
5174   const APInt &MinC = MinCOp->getAPIntValue();
5175   const APInt &MaxC = MaxCOp->getAPIntValue();
5176   APInt MinCPlus1 = MinC + 1;
5177   if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5178     BW = MinCPlus1.exactLogBase2() + 1;
5179     Unsigned = false;
5180     return N02;
5181   }
5182 
5183   if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5184     BW = MinCPlus1.exactLogBase2();
5185     Unsigned = true;
5186     return N02;
5187   }
5188 
5189   return SDValue();
5190 }
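// Concrete instance of the pattern above:
//   smin(smax(x, -128), 127)
// has MinC = 127, MaxC = -128; MinC + 1 = 128 = 2^7 = -MaxC, so BW = 8 and
// Unsigned = false -- a clamp to the signed i8 range. With MaxC = 0 and
// MinC = 255 it would instead report an unsigned clamp with BW = 8.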
5191 
5192 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5193                                            SDValue N3, ISD::CondCode CC,
5194                                            SelectionDAG &DAG) {
5195   unsigned BW;
5196   bool Unsigned;
5197   SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned);
5198   if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5199     return SDValue();
5200   EVT FPVT = Fp.getOperand(0).getValueType();
5201   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5202   if (FPVT.isVector())
5203     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5204                              FPVT.getVectorElementCount());
5205   unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5206   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
5207     return SDValue();
5208   SDLoc DL(Fp);
5209   SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
5210                             DAG.getValueType(NewVT.getScalarType()));
5211   return Unsigned ? DAG.getZExtOrTrunc(Sat, DL, N2->getValueType(0))
5212                   : DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0));
5213 }
5214 
5215 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5216                                          SDValue N3, ISD::CondCode CC,
5217                                          SelectionDAG &DAG) {
5218   // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5219   // select/vselect/select_cc. The select operands (N2/N3) may be truncated
5220   // versions of the setcc operands (N0/N1).
5221   if ((N0 != N2 &&
5222        (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
5223       N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5224     return SDValue();
5225   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5226   ConstantSDNode *N3C = isConstOrConstSplat(N3);
5227   if (!N1C || !N3C)
5228     return SDValue();
5229   const APInt &C1 = N1C->getAPIntValue();
5230   const APInt &C3 = N3C->getAPIntValue();
5231   if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5232       C1 != C3.zext(C1.getBitWidth()))
5233     return SDValue();
5234 
5235   unsigned BW = (C1 + 1).exactLogBase2();
5236   EVT FPVT = N0.getOperand(0).getValueType();
5237   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5238   if (FPVT.isVector())
5239     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5240                              FPVT.getVectorElementCount());
5241   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
5242                                                         FPVT, NewVT))
5243     return SDValue();
5244 
5245   SDValue Sat =
5246       DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
5247                   DAG.getValueType(NewVT.getScalarType()));
5248   return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
5249 }
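// Example of the combine above: umin(fptoui float %x, 255) with SETULT
// becomes an FP_TO_UINT_SAT of %x saturating at 8 bits (BW = 8, since
// 255 + 1 = 2^8), then zero-extended or truncated back to the original type.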
5250 
5251 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5252   SDValue N0 = N->getOperand(0);
5253   SDValue N1 = N->getOperand(1);
5254   EVT VT = N0.getValueType();
5255   unsigned Opcode = N->getOpcode();
5256   SDLoc DL(N);
5257 
5258   // fold operation with constant operands.
5259   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5260     return C;
5261 
5262   // If the operands are the same, this is a no-op.
5263   if (N0 == N1)
5264     return N0;
5265 
5266   // canonicalize constant to RHS
5267   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5268       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5269     return DAG.getNode(Opcode, DL, VT, N1, N0);
5270 
5271   // fold vector ops
5272   if (VT.isVector())
5273     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5274       return FoldedVOp;
5275 
5276   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
5277   // Only do this if the current op isn't legal and the flipped is.
5278   if (!TLI.isOperationLegal(Opcode, VT) &&
5279       (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
5280       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
5281     unsigned AltOpcode;
5282     switch (Opcode) {
5283     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5284     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5285     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5286     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5287     default: llvm_unreachable("Unknown MINMAX opcode");
5288     }
5289     if (TLI.isOperationLegal(AltOpcode, VT))
5290       return DAG.getNode(AltOpcode, DL, VT, N0, N1);
5291   }
5292 
5293   if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5294     if (SDValue S = PerformMinMaxFpToSatCombine(
5295             N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5296       return S;
5297   if (Opcode == ISD::UMIN)
5298     if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
5299       return S;
5300 
5301   // Simplify the operands using demanded-bits information.
5302   if (SimplifyDemandedBits(SDValue(N, 0)))
5303     return SDValue(N, 0);
5304 
5305   return SDValue();
5306 }
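// Editorial illustration of the sign-bit flip in visitIMINMAX above (not in
// the original source): when both operands are known non-negative, e.g.
//   smax (and i32:x, 255), (and i32:y, 255)
// signed and unsigned max agree, so if SMAX is not legal but UMAX is, the
// node can be rewritten as umax of the same operands.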
5307 
5308 /// If this is a bitwise logic instruction and both operands have the same
5309 /// opcode, try to sink the other opcode after the logic instruction.
5310 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5311   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
5312   EVT VT = N0.getValueType();
5313   unsigned LogicOpcode = N->getOpcode();
5314   unsigned HandOpcode = N0.getOpcode();
5315   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
5316           LogicOpcode == ISD::XOR) && "Expected logic opcode");
5317   assert(HandOpcode == N1.getOpcode() && "Bad input!");
5318 
5319   // Bail early if none of these transforms apply.
5320   if (N0.getNumOperands() == 0)
5321     return SDValue();
5322 
5323   // FIXME: We should check number of uses of the operands to not increase
5324   //        the instruction count for all transforms.
5325 
5326   // Handle size-changing casts.
5327   SDValue X = N0.getOperand(0);
5328   SDValue Y = N1.getOperand(0);
5329   EVT XVT = X.getValueType();
5330   SDLoc DL(N);
5331   if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
5332       HandOpcode == ISD::SIGN_EXTEND) {
5333     // If both operands have other uses, this transform would create extra
5334     // instructions without eliminating anything.
5335     if (!N0.hasOneUse() && !N1.hasOneUse())
5336       return SDValue();
5337     // We need matching integer source types.
5338     if (XVT != Y.getValueType())
5339       return SDValue();
5340     // Don't create an illegal op during or after legalization. Don't ever
5341     // create an unsupported vector op.
5342     if ((VT.isVector() || LegalOperations) &&
5343         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
5344       return SDValue();
5345     // Avoid infinite looping with PromoteIntBinOp.
5346     // TODO: Should we apply desirable/legal constraints to all opcodes?
5347     if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
5348         !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
5349       return SDValue();
5350     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
5351     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5352     return DAG.getNode(HandOpcode, DL, VT, Logic);
5353   }
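  // Editorial illustration (not in the original source), subject to the
  // one-use and legality checks above:
  //   and (zext i8:x to i32), (zext i8:y to i32)
  //     --> zext (and i8:x, i8:y) to i32
  // so the logic op is done once in the narrow type and extended once.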
5354 
5355   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5356   if (HandOpcode == ISD::TRUNCATE) {
5357     // If both operands have other uses, this transform would create extra
5358     // instructions without eliminating anything.
5359     if (!N0.hasOneUse() && !N1.hasOneUse())
5360       return SDValue();
5361     // We need matching source types.
5362     if (XVT != Y.getValueType())
5363       return SDValue();
5364     // Don't create an illegal op during or after legalization.
5365     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
5366       return SDValue();
5367     // Be extra careful sinking truncate. If it's free, there's no benefit in
5368     // widening a binop. Also, don't create a logic op on an illegal type.
5369     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
5370       return SDValue();
5371     if (!TLI.isTypeLegal(XVT))
5372       return SDValue();
5373     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5374     return DAG.getNode(HandOpcode, DL, VT, Logic);
5375   }
5376 
5377   // For binops SHL/SRL/SRA/AND:
5378   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
5379   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5380        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5381       N0.getOperand(1) == N1.getOperand(1)) {
5382     // If either operand has other uses, this transform is not an improvement.
5383     if (!N0.hasOneUse() || !N1.hasOneUse())
5384       return SDValue();
5385     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5386     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
5387   }
5388 
5389   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5390   if (HandOpcode == ISD::BSWAP) {
5391     // If either operand has other uses, this transform is not an improvement.
5392     if (!N0.hasOneUse() || !N1.hasOneUse())
5393       return SDValue();
5394     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5395     return DAG.getNode(HandOpcode, DL, VT, Logic);
5396   }
5397 
5398   // For funnel shifts FSHL/FSHR:
5399   // logic_op (OP x, x1, s), (OP y, y1, s) -->
5400   //   OP (logic_op x, y), (logic_op x1, y1), s
5401   if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5402       N0.getOperand(2) == N1.getOperand(2)) {
5403     if (!N0.hasOneUse() || !N1.hasOneUse())
5404       return SDValue();
5405     SDValue X1 = N0.getOperand(1);
5406     SDValue Y1 = N1.getOperand(1);
5407     SDValue S = N0.getOperand(2);
5408     SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
5409     SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
5410     return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
5411   }
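  // Editorial illustration (not in the original source): with a shared shift
  // amount s and one-use operands,
  //   or (fshl x, x1, s), (fshl y, y1, s)
  //     --> fshl (or x, y), (or x1, y1), s
  // trading two funnel shifts for one.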
5412 
5413   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5414   // Only perform this optimization up until type legalization, before
5415   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
5416   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5417   // we don't want to undo this promotion.
5418   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5419   // on scalars.
5420   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5421        Level <= AfterLegalizeTypes) {
5422     // Input types must be integer and the same.
5423     if (XVT.isInteger() && XVT == Y.getValueType() &&
5424         !(VT.isVector() && TLI.isTypeLegal(VT) &&
5425           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
5426       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5427       return DAG.getNode(HandOpcode, DL, VT, Logic);
5428     }
5429   }
5430 
5431   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5432   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5433   // If both shuffles use the same mask, and both shuffle within a single
5434   // vector, then it is worthwhile to move the swizzle after the operation.
5435   // The type-legalizer generates this pattern when loading illegal
5436   // vector types from memory. In many cases this allows additional shuffle
5437   // optimizations.
5438   // There are other cases where moving the shuffle after the xor/and/or
5439   // is profitable even if shuffles don't perform a swizzle.
5440   // If both shuffles use the same mask, and both shuffles have the same first
5441   // or second operand, then it might still be profitable to move the shuffle
5442   // after the xor/and/or operation.
5443   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5444     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
5445     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
5446     assert(X.getValueType() == Y.getValueType() &&
5447            "Inputs to shuffles are not the same type");
5448 
5449     // Check that both shuffles use the same mask. The masks are known to be of
5450     // the same length because the result vector type is the same.
5451     // Check also that shuffles have only one use to avoid introducing extra
5452     // instructions.
5453     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5454         !SVN0->getMask().equals(SVN1->getMask()))
5455       return SDValue();
5456 
5457     // Don't try to fold this node if it requires introducing a
5458     // build vector of all zeros that might be illegal at this stage.
5459     SDValue ShOp = N0.getOperand(1);
5460     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5461       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5462 
5463     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5464     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
5465       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
5466                                   N0.getOperand(0), N1.getOperand(0));
5467       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
5468     }
5469 
5470     // Don't try to fold this node if it requires introducing a
5471     // build vector of all zeros that might be illegal at this stage.
5472     ShOp = N0.getOperand(0);
5473     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5474       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5475 
5476     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5477     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
5478       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
5479                                   N1.getOperand(1));
5480       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
5481     }
5482   }
5483 
5484   return SDValue();
5485 }
5486 
5487 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5488 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5489                                        const SDLoc &DL) {
5490   SDValue LL, LR, RL, RR, N0CC, N1CC;
5491   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
5492       !isSetCCEquivalent(N1, RL, RR, N1CC))
5493     return SDValue();
5494 
5495   assert(N0.getValueType() == N1.getValueType() &&
5496          "Unexpected operand types for bitwise logic op");
5497   assert(LL.getValueType() == LR.getValueType() &&
5498          RL.getValueType() == RR.getValueType() &&
5499          "Unexpected operand types for setcc");
5500 
5501   // If we're here post-legalization or the logic op type is not i1, the logic
5502   // op type must match a setcc result type. Also, all folds require new
5503   // operations on the left and right operands, so those types must match.
5504   EVT VT = N0.getValueType();
5505   EVT OpVT = LL.getValueType();
5506   if (LegalOperations || VT.getScalarType() != MVT::i1)
5507     if (VT != getSetCCResultType(OpVT))
5508       return SDValue();
5509   if (OpVT != RL.getValueType())
5510     return SDValue();
5511 
5512   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
5513   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
5514   bool IsInteger = OpVT.isInteger();
5515   if (LR == RR && CC0 == CC1 && IsInteger) {
5516     bool IsZero = isNullOrNullSplat(LR);
5517     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
5518 
5519     // All bits clear?
5520     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5521     // All sign bits clear?
5522     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5523     // Any bits set?
5524     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5525     // Any sign bits set?
5526     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5527 
5528     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
5529     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5530     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
5531     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
5532     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5533       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
5534       AddToWorklist(Or.getNode());
5535       return DAG.getSetCC(DL, VT, Or, LR, CC1);
5536     }
5537 
5538     // All bits set?
5539     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5540     // All sign bits set?
5541     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5542     // Any bits clear?
5543     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5544     // Any sign bits clear?
5545     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5546 
5547     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
5548     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
5549     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
5550     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
5551     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
5552       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
5553       AddToWorklist(And.getNode());
5554       return DAG.getSetCC(DL, VT, And, LR, CC1);
5555     }
5556   }
5557 
5558   // TODO: What is the 'or' equivalent of this fold?
5559   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
5560   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
5561       IsInteger && CC0 == ISD::SETNE &&
5562       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
5563        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
5564     SDValue One = DAG.getConstant(1, DL, OpVT);
5565     SDValue Two = DAG.getConstant(2, DL, OpVT);
5566     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
5567     AddToWorklist(Add.getNode());
5568     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
5569   }
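  // Worked example (editorial; not in the original source), for i8 X:
  //   X = 0  -> add = 1, and 1 u>= 2 is false (setne X, 0 fails)
  //   X = -1 -> add = 0, and 0 u>= 2 is false (setne X, -1 fails)
  //   X = 5  -> add = 6, and 6 u>= 2 is true  (both setne's hold)
  // so (setuge (add X, 1), 2) matches the original conjunction.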
5570 
5571   // Try more general transforms if the predicates match and the only user of
5572   // the compares is the 'and' or 'or'.
5573   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
5574       N0.hasOneUse() && N1.hasOneUse()) {
5575     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
5576     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
5577     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
5578       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
5579       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
5580       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
5581       SDValue Zero = DAG.getConstant(0, DL, OpVT);
5582       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
5583     }
5584 
5585     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
5586     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
5587       // Match a shared variable operand and 2 non-opaque constant operands.
5588       auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
5589         // The difference of the constants must be a single bit.
5590         const APInt &CMax =
5591             APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
5592         const APInt &CMin =
5593             APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
5594         return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
5595       };
5596       if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
5597         // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
5598         // setcc (and (sub X, CMin), ~(CMax - CMin)), 0, ne/eq
5599         SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
5600         SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
5601         SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
5602         SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
5603         SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
5604         SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
5605         SDValue Zero = DAG.getConstant(0, DL, OpVT);
5606         return DAG.getSetCC(DL, VT, And, Zero, CC0);
5607       }
5608     }
5609   }
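  // Worked example for the constant-difference fold above (editorial; not in
  // the original source): for
  //   and (setcc X, 7, ne), (setcc X, 5, ne)
  // CMax - CMin = 2 is a power of 2, so this becomes
  //   setcc (and (sub X, 5), ~2), 0, ne
  // X = 5 -> 0 & ~2 = 0 (false); X = 7 -> 2 & ~2 = 0 (false);
  // X = 6 -> 1 & ~2 = 1 (true), matching the original predicate.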
5610 
5611   // Canonicalize equivalent operands to LL == RL.
5612   if (LL == RR && LR == RL) {
5613     CC1 = ISD::getSetCCSwappedOperands(CC1);
5614     std::swap(RL, RR);
5615   }
5616 
5617   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5618   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5619   if (LL == RL && LR == RR) {
5620     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
5621                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
5622     if (NewCC != ISD::SETCC_INVALID &&
5623         (!LegalOperations ||
5624          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
5625           TLI.isOperationLegal(ISD::SETCC, OpVT))))
5626       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
5627   }
5628 
5629   return SDValue();
5630 }
5631 
5632 /// This contains all DAGCombine rules which reduce two values combined by
5633 /// an And operation to a single value. This makes them reusable in the context
5634 /// of visitSELECT(). Rules involving constants are not included as
5635 /// visitSELECT() already handles those cases.
5636 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
5637   EVT VT = N1.getValueType();
5638   SDLoc DL(N);
5639 
5640   // fold (and x, undef) -> 0
5641   if (N0.isUndef() || N1.isUndef())
5642     return DAG.getConstant(0, DL, VT);
5643 
5644   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
5645     return V;
5646 
5647   // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
5648   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
5649       VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
5650     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5651       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
5652         // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
5653         // immediate for an add, but becomes legal once its top c2 bits are
5654         // set, transform the ADD so the immediate doesn't need to be
5655         // materialized in a register.
5656         APInt ADDC = ADDI->getAPIntValue();
5657         APInt SRLC = SRLI->getAPIntValue();
5658         if (ADDC.getMinSignedBits() <= 64 &&
5659             SRLC.ult(VT.getSizeInBits()) &&
5660             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5661           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
5662                                              SRLC.getZExtValue());
5663           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
5664             ADDC |= Mask;
5665             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5666               SDLoc DL0(N0);
5667               SDValue NewAdd =
5668                 DAG.getNode(ISD::ADD, DL0, VT,
5669                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
5670               CombineTo(N0.getNode(), NewAdd);
5671               // Return N so it doesn't get rechecked!
5672               return SDValue(N, 0);
5673             }
5674           }
5675         }
5676       }
5677     }
5678   }
5679 
5680   // Reduce bit extract of low half of an integer to the narrower type.
5681   // (and (srl i64:x, K), KMask) ->
5682   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K), KMask))
5683   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
5684     if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
5685       if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5686         unsigned Size = VT.getSizeInBits();
5687         const APInt &AndMask = CAnd->getAPIntValue();
5688         unsigned ShiftBits = CShift->getZExtValue();
5689 
5690         // Bail out, this node will probably disappear anyway.
5691         if (ShiftBits == 0)
5692           return SDValue();
5693 
5694         unsigned MaskBits = AndMask.countTrailingOnes();
5695         EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
5696 
5697         if (AndMask.isMask() &&
5698             // Required bits must not span the two halves of the integer and
5699             // must fit in the half size type.
5700             (ShiftBits + MaskBits <= Size / 2) &&
5701             TLI.isNarrowingProfitable(VT, HalfVT) &&
5702             TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
5703             TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
5704             TLI.isTruncateFree(VT, HalfVT) &&
5705             TLI.isZExtFree(HalfVT, VT)) {
5706           // The isNarrowingProfitable check is to avoid regressions on PPC and
5707           // AArch64 which match a few 64-bit bit insert / bit extract patterns
5708           // on downstream users of this. Those patterns could probably be
5709           // extended to handle extensions mixed in.
5710 
5711           SDValue SL(N0);
5712           assert(MaskBits <= Size);
5713 
5714           // Extracting the highest bit of the low half.
5715           EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
5716           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
5717                                       N0.getOperand(0));
5718 
5719           SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
5720           SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
5721           SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
5722           SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
5723           return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
5724         }
5725       }
5726     }
5727   }
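  // Editorial illustration (not in the original source): for i64 x,
  //   and (srl i64:x, 8), 0xFF
  // has ShiftBits = 8 and MaskBits = 8, which fit in the low half, so when
  // the target hooks above agree it becomes
  //   zero_extend (and (srl (trunc i64:x to i32), 8), 0xFF)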
5728 
5729   return SDValue();
5730 }
5731 
5732 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
5733                                    EVT LoadResultTy, EVT &ExtVT) {
5734   if (!AndC->getAPIntValue().isMask())
5735     return false;
5736 
5737   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5738 
5739   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5740   EVT LoadedVT = LoadN->getMemoryVT();
5741 
5742   if (ExtVT == LoadedVT &&
5743       (!LegalOperations ||
5744        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5745     // ZEXTLOAD will match without needing to change the size of the value being
5746     // loaded.
5747     return true;
5748   }
5749 
5750   // Do not change the width of volatile or atomic loads.
5751   if (!LoadN->isSimple())
5752     return false;
5753 
5754   // Do not generate loads of non-round integer types since these can
5755   // be expensive (and would be wrong if the type is not byte sized).
5756   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5757     return false;
5758 
5759   if (LegalOperations &&
5760       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5761     return false;
5762 
5763   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5764     return false;
5765 
5766   return true;
5767 }
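// Editorial illustration for isAndLoadExtLoad above (not in the original
// source): for (and (load i32 p), 255) the mask is an 8-bit low mask, so
// ExtVT = i8; if an i8 ZEXTLOAD producing an i32 result is legal (or we are
// before legalization), the load can be rewritten as (zextload i8 p).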
5768 
5769 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
5770                                     ISD::LoadExtType ExtType, EVT &MemVT,
5771                                     unsigned ShAmt) {
5772   if (!LDST)
5773     return false;
5774   // Only allow byte offsets.
5775   if (ShAmt % 8)
5776     return false;
5777 
5778   // Do not generate loads of non-round integer types since these can
5779   // be expensive (and would be wrong if the type is not byte sized).
5780   if (!MemVT.isRound())
5781     return false;
5782 
5783   // Don't change the width of volatile or atomic loads.
5784   if (!LDST->isSimple())
5785     return false;
5786 
5787   EVT LdStMemVT = LDST->getMemoryVT();
5788 
5789   // Bail out when changing the scalable property, since we can't be sure that
5790   // we're actually narrowing here.
5791   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
5792     return false;
5793 
5794   // Verify that we are actually reducing a load width here.
5795   if (LdStMemVT.bitsLT(MemVT))
5796     return false;
5797 
5798   // Ensure that this isn't going to produce an unsupported memory access.
5799   if (ShAmt) {
5800     assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
5801     const unsigned ByteShAmt = ShAmt / 8;
5802     const Align LDSTAlign = LDST->getAlign();
5803     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
5804     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
5805                                 LDST->getAddressSpace(), NarrowAlign,
5806                                 LDST->getMemOperand()->getFlags()))
5807       return false;
5808   }
5809 
5810   // It's not possible to generate a constant of extended or untyped type.
5811   EVT PtrType = LDST->getBasePtr().getValueType();
5812   if (PtrType == MVT::Untyped || PtrType.isExtended())
5813     return false;
5814 
5815   if (isa<LoadSDNode>(LDST)) {
5816     LoadSDNode *Load = cast<LoadSDNode>(LDST);
5817     // Don't transform one with multiple uses; this would require adding a new
5818     // load.
5819     if (!SDValue(Load, 0).hasOneUse())
5820       return false;
5821 
5822     if (LegalOperations &&
5823         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
5824       return false;
5825 
5826     // For the transform to be legal, the load must produce only two values
5827     // (the value loaded and the chain).  Don't transform a pre-increment
5828     // load, for example, which produces an extra value.  Otherwise the
5829     // transformation is not equivalent, and the downstream logic to replace
5830     // uses gets things wrong.
5831     if (Load->getNumValues() > 2)
5832       return false;
5833 
5834     // If the load that we're shrinking is an extload and we're not just
5835     // discarding the extension we can't simply shrink the load. Bail.
5836     // TODO: It would be possible to merge the extensions in some cases.
5837     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
5838         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5839       return false;
5840 
5841     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
5842       return false;
5843   } else {
5844     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
5845     StoreSDNode *Store = cast<StoreSDNode>(LDST);
5846     // Can't write outside the original store
5847     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5848       return false;
5849 
5850     if (LegalOperations &&
5851         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
5852       return false;
5853   }
5854   return true;
5855 }
5856 
5857 bool DAGCombiner::SearchForAndLoads(SDNode *N,
5858                                     SmallVectorImpl<LoadSDNode*> &Loads,
5859                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
5860                                     ConstantSDNode *Mask,
5861                                     SDNode *&NodeToMask) {
5862   // Recursively search for the operands, looking for loads which can be
5863   // narrowed.
5864   for (SDValue Op : N->op_values()) {
5865     if (Op.getValueType().isVector())
5866       return false;
5867 
5868     // Some constants may need fixing up later if they are too large.
5869     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
5870       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
5871           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
5872         NodesWithConsts.insert(N);
5873       continue;
5874     }
5875 
5876     if (!Op.hasOneUse())
5877       return false;
5878 
5879     switch(Op.getOpcode()) {
5880     case ISD::LOAD: {
5881       auto *Load = cast<LoadSDNode>(Op);
5882       EVT ExtVT;
5883       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
5884           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
5885 
5886         // ZEXTLOAD is already small enough.
5887         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
5888             ExtVT.bitsGE(Load->getMemoryVT()))
5889           continue;
5890 
5891         // Use LE to convert equal sized loads to zext.
5892         if (ExtVT.bitsLE(Load->getMemoryVT()))
5893           Loads.push_back(Load);
5894 
5895         continue;
5896       }
5897       return false;
5898     }
5899     case ISD::ZERO_EXTEND:
5900     case ISD::AssertZext: {
5901       unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
5902       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5903       EVT VT = Op.getOpcode() == ISD::AssertZext ?
5904         cast<VTSDNode>(Op.getOperand(1))->getVT() :
5905         Op.getOperand(0).getValueType();
5906 
5907       // We can accept extending nodes if the mask is wider than or equal
5908       // in width to the original type.
5909       if (ExtVT.bitsGE(VT))
5910         continue;
5911       break;
5912     }
5913     case ISD::OR:
5914     case ISD::XOR:
5915     case ISD::AND:
5916       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
5917                              NodeToMask))
5918         return false;
5919       continue;
5920     }
5921 
5922     // Allow one node which will be masked along with any loads found.
5923     if (NodeToMask)
5924       return false;
5925 
5926     // Also ensure that the node to be masked only produces one data result.
5927     NodeToMask = Op.getNode();
5928     if (NodeToMask->getNumValues() > 1) {
5929       bool HasValue = false;
5930       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
5931         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
5932         if (VT != MVT::Glue && VT != MVT::Other) {
5933           if (HasValue) {
5934             NodeToMask = nullptr;
5935             return false;
5936           }
5937           HasValue = true;
5938         }
5939       }
5940       assert(HasValue && "Node to be masked has no data result?");
5941     }
5942   }
5943   return true;
5944 }
5945 
5946 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
5947   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
5948   if (!Mask)
5949     return false;
5950 
5951   if (!Mask->getAPIntValue().isMask())
5952     return false;
5953 
5954   // No need to do anything if the and directly uses a load.
5955   if (isa<LoadSDNode>(N->getOperand(0)))
5956     return false;
5957 
5958   SmallVector<LoadSDNode*, 8> Loads;
5959   SmallPtrSet<SDNode*, 2> NodesWithConsts;
5960   SDNode *FixupNode = nullptr;
5961   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
5962     if (Loads.size() == 0)
5963       return false;
5964 
5965     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
5966     SDValue MaskOp = N->getOperand(1);
5967 
5968     // If it exists, fixup the single node we allow in the tree that needs
5969     // masking.
5970     if (FixupNode) {
5971       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
5972       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
5973                                 FixupNode->getValueType(0),
5974                                 SDValue(FixupNode, 0), MaskOp);
5975       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
5976       if (And.getOpcode() == ISD::AND)
5977         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5978     }
5979 
5980     // Narrow any constants that need it.
5981     for (auto *LogicN : NodesWithConsts) {
5982       SDValue Op0 = LogicN->getOperand(0);
5983       SDValue Op1 = LogicN->getOperand(1);
5984 
5985       if (isa<ConstantSDNode>(Op0))
5986         std::swap(Op0, Op1);
5987 
5988       SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5989                                 Op1, MaskOp);
5990 
5991       DAG.UpdateNodeOperands(LogicN, Op0, And);
5992     }
5993 
5994     // Create narrow loads.
5995     for (auto *Load : Loads) {
5996       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5997       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5998                                 SDValue(Load, 0), MaskOp);
5999       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
6000       if (And.getOpcode() == ISD::AND)
6001         And = SDValue(
6002             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
6003       SDValue NewLoad = reduceLoadWidth(And.getNode());
6004       assert(NewLoad &&
6005              "Shouldn't be masking the load if it can't be narrowed");
6006       CombineTo(Load, NewLoad, NewLoad.getValue(1));
6007     }
6008     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
6009     return true;
6010   }
6011   return false;
6012 }
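// Editorial sketch for BackwardsPropagateMask above (not in the original
// source): given
//   and (or (load i32 a), (load i32 b)), 0xFF
// the mask is pushed back through the OR to each load, which can then be
// narrowed, e.g. to
//   or (zextload i8 a), (zextload i8 b)
// after which the outer AND is redundant and is removed.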
6013 
6014 // Unfold
6015 //    x &  (-1 'logical shift' y)
6016 // To
6017 //    (x 'opposite logical shift' y) 'logical shift' y
6018 // if it is better for performance.
6019 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6020   assert(N->getOpcode() == ISD::AND);
6021 
6022   SDValue N0 = N->getOperand(0);
6023   SDValue N1 = N->getOperand(1);
6024 
6025   // Do we actually prefer shifts over mask?
6026   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
6027     return SDValue();
6028 
6029   // Try to match  (-1 '[outer] logical shift' y)
6030   unsigned OuterShift;
6031   unsigned InnerShift; // The opposite direction to the OuterShift.
6032   SDValue Y;           // Shift amount.
6033   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6034     if (!M.hasOneUse())
6035       return false;
6036     OuterShift = M->getOpcode();
6037     if (OuterShift == ISD::SHL)
6038       InnerShift = ISD::SRL;
6039     else if (OuterShift == ISD::SRL)
6040       InnerShift = ISD::SHL;
6041     else
6042       return false;
6043     if (!isAllOnesConstant(M->getOperand(0)))
6044       return false;
6045     Y = M->getOperand(1);
6046     return true;
6047   };
6048 
6049   SDValue X;
6050   if (matchMask(N1))
6051     X = N0;
6052   else if (matchMask(N0))
6053     X = N1;
6054   else
6055     return SDValue();
6056 
6057   SDLoc DL(N);
6058   EVT VT = N->getValueType(0);
6059 
6060   //     tmp = x   'opposite logical shift' y
6061   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
6062   //     ret = tmp 'logical shift' y
6063   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
6064 
6065   return T1;
6066 }
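// Editorial illustration for the unfold above (not in the original source):
//   x & (-1 << y)  -->  (x >> y) << y   (clears the low y bits)
//   x & (-1 >> y)  -->  (x << y) >> y   (clears the high y bits)
// using logical shifts, when shouldFoldMaskToVariableShiftPair prefers a
// shift pair over materializing the mask.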
6067 
6068 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6069 /// For a target with a bit test, this is expected to become test + set and save
6070 /// at least 1 instruction.
6071 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6072   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6073 
6074   // This is probably not worthwhile without a supported type.
6075   EVT VT = And->getValueType(0);
6076   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6077   if (!TLI.isTypeLegal(VT))
6078     return SDValue();
6079 
6080   // Look through an optional extension.
6081   SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
6082   if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6083     And0 = And0.getOperand(0);
6084   if (!isOneConstant(And1) || !And0.hasOneUse())
6085     return SDValue();
6086 
6087   SDValue Src = And0;
6088 
6089   // Attempt to find a 'not' op.
6090   // TODO: Should we favor test+set even without the 'not' op?
6091   bool FoundNot = false;
6092   if (isBitwiseNot(Src)) {
6093     FoundNot = true;
6094     Src = Src.getOperand(0);
6095 
6096     // Look through an optional truncation. The source operand may not be the
6097     // same type as the original 'and', but that is ok because we are masking
6098     // off everything but the low bit.
6099     if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6100       Src = Src.getOperand(0);
6101   }
6102 
6103   // Match a shift-right by constant.
6104   if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6105     return SDValue();
6106 
6107   // We might have looked through casts that make this transform invalid.
6108   // TODO: If the source type is wider than the result type, do the mask and
6109   //       compare in the source type.
6110   unsigned VTBitWidth = VT.getScalarSizeInBits();
6111   SDValue ShiftAmt = Src.getOperand(1);
6112   auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
6113   if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth))
6114     return SDValue();
6115 
6116   // Set source to shift source.
6117   Src = Src.getOperand(0);
6118 
6119   // Try again to find a 'not' op.
6120   // TODO: Should we favor test+set even with two 'not' ops?
6121   if (!FoundNot) {
6122     if (!isBitwiseNot(Src))
6123       return SDValue();
6124     Src = Src.getOperand(0);
6125   }
6126 
6127   if (!TLI.hasBitTest(Src, ShiftAmt))
6128     return SDValue();
6129 
6130   // Turn this into a bit-test pattern using mask op + setcc:
6131   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6132   // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
6133   SDLoc DL(And);
6134   SDValue X = DAG.getZExtOrTrunc(Src, DL, VT);
6135   EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
6136   SDValue Mask = DAG.getConstant(
6137       APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT);
6138   SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
6139   SDValue Zero = DAG.getConstant(0, DL, VT);
6140   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
6141   return DAG.getZExtOrTrunc(Setcc, DL, VT);
6142 }
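// Worked example for the bit test above (editorial; not in the original
// source): for i32 X,
//   and (not (srl X, 3)), 1  -->  zext ((and X, 8) == 0)
// X = 8 -> (8 & 8) == 0 is false (bit 3 set, original yields 0);
// X = 4 -> (4 & 8) == 0 is true  (bit 3 clear, original yields 1).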
6143 
6144 /// For targets that support usubsat, match a bit-hack form of that operation
6145 /// that ends in 'and' and convert it.
6146 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6147   SDValue N0 = N->getOperand(0);
6148   SDValue N1 = N->getOperand(1);
6149   EVT VT = N1.getValueType();
6150 
6151   // Canonicalize SRA as operand 1.
6152   if (N0.getOpcode() == ISD::SRA)
6153     std::swap(N0, N1);
6154 
6155   // xor/add with SMIN (signmask) are logically equivalent.
6156   if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6157     return SDValue();
6158 
6159   if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6160       N0.getOperand(0) != N1.getOperand(0))
6161     return SDValue();
6162 
6163   unsigned BitWidth = VT.getScalarSizeInBits();
6164   ConstantSDNode *XorC = isConstOrConstSplat(N0.getOperand(1), true);
6165   ConstantSDNode *SraC = isConstOrConstSplat(N1.getOperand(1), true);
6166   if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6167       !SraC || SraC->getAPIntValue() != BitWidth - 1)
6168     return SDValue();
6169 
6170   // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6171   // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6172   SDLoc DL(N);
6173   SDValue SignMask = DAG.getConstant(XorC->getAPIntValue(), DL, VT);
6174   return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask);
6175 }
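// Worked example for foldAndToUsubsat above (editorial; not in the original
// source), for i8 X:
//   X = 200 (0xC8): X ^ 128 = 72, X s>> 7 = 0xFF, so the AND gives 72,
//                   which equals usubsat(200, 128).
//   X = 100:        X s>> 7 = 0, so the AND gives 0 = usubsat(100, 128).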
6176 
6177 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6178 /// fold a pattern where 2 of the source operands are identically shifted
6179 /// values. For example:
6180 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6181 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6182                                  SelectionDAG &DAG) {
6183   unsigned LogicOpcode = N->getOpcode();
6184   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6185           LogicOpcode == ISD::XOR)
6186          && "Expected bitwise logic operation");
6187 
6188   if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6189     return SDValue();
6190 
6191   // Match another bitwise logic op and a shift.
6192   unsigned ShiftOpcode = ShiftOp.getOpcode();
6193   if (LogicOp.getOpcode() != LogicOpcode ||
6194       !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6195         ShiftOpcode == ISD::SRA))
6196     return SDValue();
6197 
6198   // Match another shift op inside the first logic operand. Handle both commuted
6199   // possibilities.
6200   // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6201   // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6202   SDValue X1 = ShiftOp.getOperand(0);
6203   SDValue Y = ShiftOp.getOperand(1);
6204   SDValue X0, Z;
6205   if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6206       LogicOp.getOperand(0).getOperand(1) == Y) {
6207     X0 = LogicOp.getOperand(0).getOperand(0);
6208     Z = LogicOp.getOperand(1);
6209   } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6210              LogicOp.getOperand(1).getOperand(1) == Y) {
6211     X0 = LogicOp.getOperand(1).getOperand(0);
6212     Z = LogicOp.getOperand(0);
6213   } else {
6214     return SDValue();
6215   }
6216 
6217   EVT VT = N->getValueType(0);
6218   SDLoc DL(N);
6219   SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6220   SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6221   return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6222 }
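// Editorial illustration for foldLogicOfShifts above (not in the original
// source): with a shared shift amount Y and one-use operands,
//   ((X0 << Y) | Z) | (X1 << Y)  -->  ((X0 | X1) << Y) | Z
// keeping the logic-op count the same while removing one shift.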
6223 
6224 /// Given a tree of logic operations with shape like
6225 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6226 /// try to match and fold shift operations with the same shift amount.
6227 /// For example:
6228 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6229 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6230 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6231                                      SDValue RightHand, SelectionDAG &DAG) {
6232   unsigned LogicOpcode = N->getOpcode();
6233   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6234           LogicOpcode == ISD::XOR));
6235   if (LeftHand.getOpcode() != LogicOpcode ||
6236       RightHand.getOpcode() != LogicOpcode)
6237     return SDValue();
6238   if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6239     return SDValue();
6240 
6241   // Try to match one of following patterns:
6242   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6243   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6244   // Note that foldLogicOfShifts will handle commuted versions of the left hand
6245   // itself.
6246   SDValue CombinedShifts, W;
6247   SDValue R0 = RightHand.getOperand(0);
6248   SDValue R1 = RightHand.getOperand(1);
6249   if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
6250     W = R1;
6251   else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
6252     W = R0;
6253   else
6254     return SDValue();
6255 
6256   EVT VT = N->getValueType(0);
6257   SDLoc DL(N);
6258   return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
6259 }
6260 
6261 SDValue DAGCombiner::visitAND(SDNode *N) {
6262   SDValue N0 = N->getOperand(0);
6263   SDValue N1 = N->getOperand(1);
6264   EVT VT = N1.getValueType();
6265 
6266   // x & x --> x
6267   if (N0 == N1)
6268     return N0;
6269 
6270   // fold (and c1, c2) -> c1&c2
6271   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
6272     return C;
6273 
6274   // canonicalize constant to RHS
6275   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6276       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6277     return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
6278 
6279   // fold vector ops
6280   if (VT.isVector()) {
6281     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
6282       return FoldedVOp;
6283 
6284     // fold (and x, 0) -> 0, vector edition
6285     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
6286       // do not return N1, because an undef node may exist in N1
6287       return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()),
6288                              SDLoc(N), N1.getValueType());
6289 
6290     // fold (and x, -1) -> x, vector edition
6291     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
6292       return N0;
6293 
6294     // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6295     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
6296     ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
6297     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6298         N1.hasOneUse()) {
6299       EVT LoadVT = MLoad->getMemoryVT();
6300       EVT ExtVT = VT;
6301       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
6302         // For this AND to be a zero extension of the masked load, the
6303         // elements of the BuildVec must mask the bottom bits of the extended
6304         // element type.
6305         uint64_t ElementSize =
6306             LoadVT.getVectorElementType().getScalarSizeInBits();
6307         if (Splat->getAPIntValue().isMask(ElementSize)) {
6308           auto NewLoad = DAG.getMaskedLoad(
6309               ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
6310               MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
6311               LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
6312               ISD::ZEXTLOAD, MLoad->isExpandingLoad());
6313           bool LoadHasOtherUsers = !N0.hasOneUse();
6314           CombineTo(N, NewLoad);
6315           if (LoadHasOtherUsers)
6316             CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
6317           return SDValue(N, 0);
6318         }
6319       }
6320     }
6321   }
6322 
6323   // fold (and x, -1) -> x
6324   if (isAllOnesConstant(N1))
6325     return N0;
6326 
6327   // if (and x, c) is known to be zero, return 0
6328   unsigned BitWidth = VT.getScalarSizeInBits();
6329   ConstantSDNode *N1C = isConstOrConstSplat(N1);
6330   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
6331     return DAG.getConstant(0, SDLoc(N), VT);
6332 
6333   if (SDValue NewSel = foldBinOpIntoSelect(N))
6334     return NewSel;
6335 
6336   // reassociate and
6337   if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
6338     return RAND;
6339 
6340   // fold (and (or x, C), D) -> D if (C & D) == D
6341   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
6342     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
6343   };
6344   if (N0.getOpcode() == ISD::OR &&
6345       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
6346     return N1;
6347 
6348   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
6349   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
6350     SDValue N0Op0 = N0.getOperand(0);
6351     APInt Mask = ~N1C->getAPIntValue();
6352     Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
6353     if (DAG.MaskedValueIsZero(N0Op0, Mask))
6354       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N0.getValueType(), N0Op0);
6355   }
6356 
6357   // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
6358   if (ISD::isExtOpcode(N0.getOpcode())) {
6359     unsigned ExtOpc = N0.getOpcode();
6360     SDValue N0Op0 = N0.getOperand(0);
6361     if (N0Op0.getOpcode() == ISD::AND &&
6362         (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
6363         DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
6364         DAG.isConstantIntBuildVectorOrConstantInt(N0Op0.getOperand(1)) &&
6365         N0->hasOneUse() && N0Op0->hasOneUse()) {
6366       SDLoc DL(N);
6367       SDValue NewMask =
6368           DAG.getNode(ISD::AND, DL, VT, N1,
6369                       DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(1)));
6370       return DAG.getNode(ISD::AND, DL, VT,
6371                          DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
6372                          NewMask);
6373     }
6374   }
6375 
6376   // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
6377   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
6378   // already be zero by virtue of the width of the base type of the load.
6379   //
6380   // the 'X' node here can either be nothing or an extract_vector_elt to catch
6381   // more cases.
6382   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
6383        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
6384        N0.getOperand(0).getOpcode() == ISD::LOAD &&
6385        N0.getOperand(0).getResNo() == 0) ||
6386       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
6387     LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
6388                                          N0 : N0.getOperand(0) );
6389 
6390     // Get the constant (if applicable) the zero'th operand is being ANDed with.
6391     // This can be a pure constant or a vector splat, in which case we treat the
6392     // vector as a scalar and use the splat value.
6393     APInt Constant = APInt::getZero(1);
6394     if (const ConstantSDNode *C = isConstOrConstSplat(
6395             N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
6396       Constant = C->getAPIntValue();
6397     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
6398       APInt SplatValue, SplatUndef;
6399       unsigned SplatBitSize;
6400       bool HasAnyUndefs;
6401       bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
6402                                              SplatBitSize, HasAnyUndefs);
6403       if (IsSplat) {
6404         // Undef bits can contribute to a possible optimisation if set, so
6405         // set them.
6406         SplatValue |= SplatUndef;
6407 
6408         // The splat value may be something like "0x00FFFFFF", which means 0 for
6409         // the first vector value and FF for the rest, repeating. We need a mask
6410         // that will apply equally to all members of the vector, so AND all the
6411         // lanes of the constant together.
6412         unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
6413 
6414         // If the splat value has been compressed to a bitlength lower
6415         // than the size of the vector lane, we need to re-expand it to
6416         // the lane size.
6417         if (EltBitWidth > SplatBitSize)
6418           for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
6419                SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
6420             SplatValue |= SplatValue.shl(SplatBitSize);
6421 
6422         // Make sure that variable 'Constant' is only set if 'SplatBitSize' is
6423         // a multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong
6424         // value.
6424         if ((SplatBitSize % EltBitWidth) == 0) {
6425           Constant = APInt::getAllOnes(EltBitWidth);
6426           for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
6427             Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
6428         }
6429       }
6430     }
6431 
6432     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
6433     // actually legal and isn't going to get expanded, else this is a false
6434     // optimisation.
6435     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
6436                                                     Load->getValueType(0),
6437                                                     Load->getMemoryVT());
6438 
6439     // Resize the constant to the same size as the original memory access before
6440     // extension. If it is still the AllOnesValue then this AND is completely
6441     // unneeded.
6442     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
6443 
6444     bool B;
6445     switch (Load->getExtensionType()) {
6446     default: B = false; break;
6447     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
6448     case ISD::ZEXTLOAD:
6449     case ISD::NON_EXTLOAD: B = true; break;
6450     }
6451 
6452     if (B && Constant.isAllOnes()) {
6453       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
6454       // preserve semantics once we get rid of the AND.
6455       SDValue NewLoad(Load, 0);
6456 
6457       // Fold the AND away. NewLoad may get replaced immediately.
6458       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
6459 
6460       if (Load->getExtensionType() == ISD::EXTLOAD) {
6461         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
6462                               Load->getValueType(0), SDLoc(Load),
6463                               Load->getChain(), Load->getBasePtr(),
6464                               Load->getOffset(), Load->getMemoryVT(),
6465                               Load->getMemOperand());
6466         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
6467         if (Load->getNumValues() == 3) {
6468           // PRE/POST_INC loads have 3 values.
6469           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
6470                            NewLoad.getValue(2) };
6471           CombineTo(Load, To, 3, true);
6472         } else {
6473           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
6474         }
6475       }
6476 
6477       return SDValue(N, 0); // Return N so it doesn't get rechecked!
6478     }
6479   }
6480 
6481   // Try to convert a constant mask AND into a shuffle clear mask.
6482   if (VT.isVector())
6483     if (SDValue Shuffle = XformToShuffleWithZero(N))
6484       return Shuffle;
6485 
6486   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
6487     return Combined;
6488 
6489   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
6490       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
6491     SDValue Ext = N0.getOperand(0);
6492     EVT ExtVT = Ext->getValueType(0);
6493     SDValue Extendee = Ext->getOperand(0);
6494 
6495     unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
6496     if (N1C->getAPIntValue().isMask(ScalarWidth) &&
6497         (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
6498       //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
6499       // => (extract_subvector (iN_zeroext v))
6500       SDValue ZeroExtExtendee =
6501           DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), ExtVT, Extendee);
6502 
6503       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, ZeroExtExtendee,
6504                          N0.getOperand(1));
6505     }
6506   }
6507 
6508   // fold (and (masked_gather x)) -> (zext_masked_gather x)
6509   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
6510     EVT MemVT = GN0->getMemoryVT();
6511     EVT ScalarVT = MemVT.getScalarType();
6512 
6513     if (SDValue(GN0, 0).hasOneUse() &&
6514         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
6515         TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
6516       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
6517                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
6518 
6519       SDValue ZExtLoad = DAG.getMaskedGather(
6520           DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
6521           GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
6522 
6523       CombineTo(N, ZExtLoad);
6524       AddToWorklist(ZExtLoad.getNode());
6525       // Avoid recheck of N.
6526       return SDValue(N, 0);
6527     }
6528   }
6529 
6530   // fold (and (load x), 255) -> (zextload x, i8)
6531   // fold (and (extload x, i16), 255) -> (zextload x, i8)
6532   if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
6533     if (SDValue Res = reduceLoadWidth(N))
6534       return Res;
6535 
6536   if (LegalTypes) {
6537     // Attempt to propagate the AND back up to the leaves which, if they're
6538     // loads, can be combined to narrow loads and the AND node can be removed.
6539     // Perform after legalization so that extend nodes will already be
6540     // combined into the loads.
6541     if (BackwardsPropagateMask(N))
6542       return SDValue(N, 0);
6543   }
6544 
6545   if (SDValue Combined = visitANDLike(N0, N1, N))
6546     return Combined;
6547 
6548   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
6549   if (N0.getOpcode() == N1.getOpcode())
6550     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
6551       return V;
6552 
6553   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
6554     return R;
6555   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
6556     return R;
6557 
6558   // Masking the negated extension of a boolean is just the zero-extended
6559   // boolean:
6560   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
6561   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
6562   //
6563   // Note: the SimplifyDemandedBits fold below can make an information-losing
6564   // transform, and then we have no way to find this better fold.
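  //
  // To see why this holds: zext(bool X) is 0 or 1, so (sub 0, zext(bool X))
  // is 0 or -1 (all ones), and masking with 1 yields 0 or 1, i.e.
  // zext(bool X). The sext case is analogous: (sub 0, sext(bool X)) is
  // already 0 or +1.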
6565   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
6566     if (isNullOrNullSplat(N0.getOperand(0))) {
6567       SDValue SubRHS = N0.getOperand(1);
6568       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
6569           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
6570         return SubRHS;
6571       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
6572           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
6573         return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
6574     }
6575   }
6576 
6577   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
6578   // fold (and (sra)) -> (and (srl)) when possible.
6579   if (SimplifyDemandedBits(SDValue(N, 0)))
6580     return SDValue(N, 0);
6581 
6582   // fold (zext_inreg (extload x)) -> (zextload x)
6583   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
6584   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
6585       (ISD::isEXTLoad(N0.getNode()) ||
6586        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
6587     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
6588     EVT MemVT = LN0->getMemoryVT();
6589     // If we zero all the possible extended bits, then we can turn this into
6590     // a zextload if we are running before legalize or the operation is legal.
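    // E.g. for an i8 extload used as i32, the possibly-extended bits are
    // bits 31:8; if the mask already clears all of them, the zero-extending
    // load supplies those zeros itself.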
6591     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
6592     unsigned MemBitSize = MemVT.getScalarSizeInBits();
6593     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
6594     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
6595         ((!LegalOperations && LN0->isSimple()) ||
6596          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
6597       SDValue ExtLoad =
6598           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
6599                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
6600       AddToWorklist(N);
6601       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
6602       return SDValue(N, 0); // Return N so it doesn't get rechecked!
6603     }
6604   }
6605 
6606   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
6607   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
6608     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
6609                                            N0.getOperand(1), false))
6610       return BSwap;
6611   }
6612 
6613   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
6614     return Shifts;
6615 
6616   if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
6617     return V;
6618 
6619   // Recognize the following pattern:
6620   //
6621   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
6622   //
6623   // where bitmask is a mask that clears the upper bits of AndVT. The
6624   // number of bits in bitmask must be a power of two.
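  //
  // For example, (and (sign_extend i8 x to i32), 0xFF) keeps exactly the low
  // 8 bits, so it can be rewritten as (zero_extend i8 x to i32).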
6625   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
6626     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
6627       return false;
6628 
6629     auto *C = dyn_cast<ConstantSDNode>(RHS);
6630     if (!C)
6631       return false;
6632 
6633     if (!C->getAPIntValue().isMask(
6634             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
6635       return false;
6636 
6637     return true;
6638   };
6639 
6640   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
6641   if (IsAndZeroExtMask(N0, N1))
6642     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));
6643 
6644   if (hasOperation(ISD::USUBSAT, VT))
6645     if (SDValue V = foldAndToUsubsat(N, DAG))
6646       return V;
6647 
6648   // Postpone until legalization completed to avoid interference with bswap
6649   // folding
6650   if (LegalOperations || VT.isVector())
6651     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
6652       return R;
6653 
6654   return SDValue();
6655 }
6656 
6657 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
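/// Only the low halfword is byteswapped; for types wider than i16 the result
/// is (bswap a) shifted right so those two bytes land at bit 0.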
6658 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
6659                                         bool DemandHighBits) {
6660   if (!LegalOperations)
6661     return SDValue();
6662 
6663   EVT VT = N->getValueType(0);
6664   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
6665     return SDValue();
6666   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6667     return SDValue();
6668 
6669   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
6670   bool LookPassAnd0 = false;
6671   bool LookPassAnd1 = false;
6672   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
6673     std::swap(N0, N1);
6674   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
6675     std::swap(N0, N1);
6676   if (N0.getOpcode() == ISD::AND) {
6677     if (!N0->hasOneUse())
6678       return SDValue();
6679     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6680     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
6681     // This is needed for X86.
6682     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
6683                   N01C->getZExtValue() != 0xFFFF))
6684       return SDValue();
6685     N0 = N0.getOperand(0);
6686     LookPassAnd0 = true;
6687   }
6688 
6689   if (N1.getOpcode() == ISD::AND) {
6690     if (!N1->hasOneUse())
6691       return SDValue();
6692     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6693     if (!N11C || N11C->getZExtValue() != 0xFF)
6694       return SDValue();
6695     N1 = N1.getOperand(0);
6696     LookPassAnd1 = true;
6697   }
6698 
6699   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
6700     std::swap(N0, N1);
6701   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
6702     return SDValue();
6703   if (!N0->hasOneUse() || !N1->hasOneUse())
6704     return SDValue();
6705 
6706   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6707   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6708   if (!N01C || !N11C)
6709     return SDValue();
6710   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
6711     return SDValue();
6712 
6713   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
6714   SDValue N00 = N0->getOperand(0);
6715   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
6716     if (!N00->hasOneUse())
6717       return SDValue();
6718     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
6719     if (!N001C || N001C->getZExtValue() != 0xFF)
6720       return SDValue();
6721     N00 = N00.getOperand(0);
6722     LookPassAnd0 = true;
6723   }
6724 
6725   SDValue N10 = N1->getOperand(0);
6726   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
6727     if (!N10->hasOneUse())
6728       return SDValue();
6729     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
6730     // Also allow 0xFFFF since the bits will be shifted out. This is needed
6731     // for X86.
6732     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
6733                    N101C->getZExtValue() != 0xFFFF))
6734       return SDValue();
6735     N10 = N10.getOperand(0);
6736     LookPassAnd1 = true;
6737   }
6738 
6739   if (N00 != N10)
6740     return SDValue();
6741 
6742   // Make sure everything beyond the low halfword gets set to zero since the SRL
6743   // 16 will clear the top bits.
6744   unsigned OpSizeInBits = VT.getSizeInBits();
6745   if (OpSizeInBits > 16) {
6746     // If the left-shift isn't masked out then the only way this is a bswap is
6747     // if all bits beyond the low 8 are 0. In that case the entire pattern
6748     // reduces to a left shift anyway: leave it for other parts of the combiner.
6749     if (DemandHighBits && !LookPassAnd0)
6750       return SDValue();
6751 
6752     // However, if the right shift isn't masked out then it might be because
6753     // it's not needed. See if we can spot that too. If the high bits aren't
6754     // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
6755     // upper bits to be zero.
6756     if (!LookPassAnd1) {
6757       unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
6758       if (!DAG.MaskedValueIsZero(N10,
6759                                  APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
6760         return SDValue();
6761     }
6762   }
6763 
6764   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
6765   if (OpSizeInBits > 16) {
6766     SDLoc DL(N);
6767     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
6768                       DAG.getConstant(OpSizeInBits - 16, DL,
6769                                       getShiftAmountTy(VT)));
6770   }
6771   return Res;
6772 }
6773 
6774 /// Return true if the specified node is an element that makes up a 32-bit
6775 /// packed halfword byteswap.
6776 /// ((x & 0x000000ff) << 8) |
6777 /// ((x & 0x0000ff00) >> 8) |
6778 /// ((x & 0x00ff0000) << 8) |
6779 /// ((x & 0xff000000) >> 8)
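/// On success, the un-shifted source node is recorded in
/// Parts[MaskByteOffset].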
6780 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
6781   if (!N->hasOneUse())
6782     return false;
6783 
6784   unsigned Opc = N.getOpcode();
6785   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
6786     return false;
6787 
6788   SDValue N0 = N.getOperand(0);
6789   unsigned Opc0 = N0.getOpcode();
6790   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
6791     return false;
6792 
6793   ConstantSDNode *N1C = nullptr;
6794   // SHL or SRL: look upstream for AND mask operand
6795   if (Opc == ISD::AND)
6796     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6797   else if (Opc0 == ISD::AND)
6798     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6799   if (!N1C)
6800     return false;
6801 
6802   unsigned MaskByteOffset;
6803   switch (N1C->getZExtValue()) {
6804   default:
6805     return false;
6806   case 0xFF:       MaskByteOffset = 0; break;
6807   case 0xFF00:     MaskByteOffset = 1; break;
6808   case 0xFFFF:
6809     // In case demanded bits didn't clear the bits that will be shifted out.
6810     // This is needed for X86.
6811     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
6812       MaskByteOffset = 1;
6813       break;
6814     }
6815     return false;
6816   case 0xFF0000:   MaskByteOffset = 2; break;
6817   case 0xFF000000: MaskByteOffset = 3; break;
6818   }
6819 
6820   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
6821   if (Opc == ISD::AND) {
6822     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
6823       // (x >> 8) & 0xff
6824       // (x >> 8) & 0xff0000
6825       if (Opc0 != ISD::SRL)
6826         return false;
6827       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6828       if (!C || C->getZExtValue() != 8)
6829         return false;
6830     } else {
6831       // (x << 8) & 0xff00
6832       // (x << 8) & 0xff000000
6833       if (Opc0 != ISD::SHL)
6834         return false;
6835       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6836       if (!C || C->getZExtValue() != 8)
6837         return false;
6838     }
6839   } else if (Opc == ISD::SHL) {
6840     // (x & 0xff) << 8
6841     // (x & 0xff0000) << 8
6842     if (MaskByteOffset != 0 && MaskByteOffset != 2)
6843       return false;
6844     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6845     if (!C || C->getZExtValue() != 8)
6846       return false;
6847   } else { // Opc == ISD::SRL
6848     // (x & 0xff00) >> 8
6849     // (x & 0xff000000) >> 8
6850     if (MaskByteOffset != 1 && MaskByteOffset != 3)
6851       return false;
6852     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6853     if (!C || C->getZExtValue() != 8)
6854       return false;
6855   }
6856 
6857   if (Parts[MaskByteOffset])
6858     return false;
6859 
6860   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
6861   return true;
6862 }
6863 
6864 // Match 2 elements of a packed halfword bswap.
6865 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
6866   if (N.getOpcode() == ISD::OR)
6867     return isBSwapHWordElement(N.getOperand(0), Parts) &&
6868            isBSwapHWordElement(N.getOperand(1), Parts);
6869 
6870   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
6871     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
6872     if (!C || C->getAPIntValue() != 16)
6873       return false;
6874     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
6875     return true;
6876   }
6877 
6878   return false;
6879 }
6880 
6881 // Match this pattern:
6882 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
6883 // And rewrite this to:
6884 //   (rotr (bswap A), 16)
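// With A = [b3 b2 b1 b0], the masked shifts produce [b2 0 b0 0] and
// [0 b3 0 b1]; their OR is [b2 b3 b0 b1], i.e. bswap(A) = [b0 b1 b2 b3]
// rotated right by 16 bits.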
6885 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
6886                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
6887                                        SDValue N1, EVT VT, EVT ShiftAmountTy) {
6888   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
6889          "matchBSwapHWordOrAndAnd: expecting i32");
6890   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6891     return SDValue();
6892   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6893     return SDValue();
6894   // TODO: this is too restrictive; lifting this restriction requires more tests
6895   if (!N0->hasOneUse() || !N1->hasOneUse())
6896     return SDValue();
6897   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6898   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6899   if (!Mask0 || !Mask1)
6900     return SDValue();
6901   if (Mask0->getAPIntValue() != 0xff00ff00 ||
6902       Mask1->getAPIntValue() != 0x00ff00ff)
6903     return SDValue();
6904   SDValue Shift0 = N0.getOperand(0);
6905   SDValue Shift1 = N1.getOperand(0);
6906   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6907     return SDValue();
6908   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6909   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6910   if (!ShiftAmt0 || !ShiftAmt1)
6911     return SDValue();
6912   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6913     return SDValue();
6914   if (Shift0.getOperand(0) != Shift1.getOperand(0))
6915     return SDValue();
6916 
6917   SDLoc DL(N);
6918   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6919   SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6920   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6921 }
6922 
6923 /// Match a 32-bit packed halfword bswap. That is
6924 /// ((x & 0x000000ff) << 8) |
6925 /// ((x & 0x0000ff00) >> 8) |
6926 /// ((x & 0x00ff0000) << 8) |
6927 /// ((x & 0xff000000) >> 8)
6928 /// => (rotl (bswap x), 16)
6929 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
6930   if (!LegalOperations)
6931     return SDValue();
6932 
6933   EVT VT = N->getValueType(0);
6934   if (VT != MVT::i32)
6935     return SDValue();
6936   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6937     return SDValue();
6938 
6939   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
6940                                               getShiftAmountTy(VT)))
6941     return BSwap;
6942 
6943   // Try again with commuted operands.
6944   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
6945                                               getShiftAmountTy(VT)))
6946     return BSwap;
6947 
6948 
6949   // Look for either
6950   // (or (bswaphpair), (bswaphpair))
6951   // (or (or (bswaphpair), (and)), (and))
6952   // (or (or (and), (bswaphpair)), (and))
6953   SDNode *Parts[4] = {};
6954 
6955   if (isBSwapHWordPair(N0, Parts)) {
6956     // (or (or (and), (and)), (or (and), (and)))
6957     if (!isBSwapHWordPair(N1, Parts))
6958       return SDValue();
6959   } else if (N0.getOpcode() == ISD::OR) {
6960     // (or (or (or (and), (and)), (and)), (and))
6961     if (!isBSwapHWordElement(N1, Parts))
6962       return SDValue();
6963     SDValue N00 = N0.getOperand(0);
6964     SDValue N01 = N0.getOperand(1);
6965     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
6966         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
6967       return SDValue();
6968   } else {
6969     return SDValue();
6970   }
6971 
6972   // Make sure the parts are all coming from the same node.
6973   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
6974     return SDValue();
6975 
6976   SDLoc DL(N);
6977   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
6978                               SDValue(Parts[0], 0));
6979 
6980   // Result of the bswap should be rotated by 16. If it's not legal, then
6981   // do (x << 16) | (x >> 16).
6982   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
6983   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
6984     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
6985   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6986     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6987   return DAG.getNode(ISD::OR, DL, VT,
6988                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
6989                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
6990 }
6991 
6992 /// This contains all DAGCombine rules which reduce two values combined by
6993 /// an Or operation to a single value \see visitANDLike().
6994 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
6995   EVT VT = N1.getValueType();
6996   SDLoc DL(N);
6997 
6998   // fold (or x, undef) -> -1
6999   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7000     return DAG.getAllOnesConstant(DL, VT);
7001 
7002   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
7003     return V;
7004 
7005   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
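  // E.g. (or (and X, 0xFF00), (and Y, 0x00FF)) -> (and (or X, Y), 0xFFFF)
  // when X's low byte and bits 15:8 of Y are known zero.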
7006   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7007       // Don't increase # computations.
7008       (N0->hasOneUse() || N1->hasOneUse())) {
7009     // We can only do this xform if we know that bits from X that are set in C2
7010     // but not in C1 are already zero.  Likewise for Y.
7011     if (const ConstantSDNode *N0O1C =
7012         getAsNonOpaqueConstant(N0.getOperand(1))) {
7013       if (const ConstantSDNode *N1O1C =
7014           getAsNonOpaqueConstant(N1.getOperand(1))) {
7017         const APInt &LHSMask = N0O1C->getAPIntValue();
7018         const APInt &RHSMask = N1O1C->getAPIntValue();
7019 
7020         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
7021             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
7022           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7023                                   N0.getOperand(0), N1.getOperand(0));
7024           return DAG.getNode(ISD::AND, DL, VT, X,
7025                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
7026         }
7027       }
7028     }
7029   }
7030 
7031   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7032   if (N0.getOpcode() == ISD::AND &&
7033       N1.getOpcode() == ISD::AND &&
7034       N0.getOperand(0) == N1.getOperand(0) &&
7035       // Don't increase # computations.
7036       (N0->hasOneUse() || N1->hasOneUse())) {
7037     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7038                             N0.getOperand(1), N1.getOperand(1));
7039     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
7040   }
7041 
7042   return SDValue();
7043 }
7044 
7045 /// OR combines for which the commuted variant will be tried as well.
7046 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7047                                   SDNode *N) {
7048   EVT VT = N0.getValueType();
7049   if (N0.getOpcode() == ISD::AND) {
7050     SDValue N00 = N0.getOperand(0);
7051     SDValue N01 = N0.getOperand(1);
7052 
7053     // fold or (and x, y), x --> x
7054     if (N00 == N1 || N01 == N1)
7055       return N1;
7056 
7057     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7058     // TODO: Set AllowUndefs = true.
7059     if (getBitwiseNotOperand(N01, N00,
7060                              /* AllowUndefs */ false) == N1)
7061       return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
7062 
7063     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7064     if (getBitwiseNotOperand(N00, N01,
7065                              /* AllowUndefs */ false) == N1)
7066       return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
7067   }
7068 
7069   if (N0.getOpcode() == ISD::XOR) {
7070     // fold or (xor x, y), x --> or x, y
7071     //      or (xor x, y), (x and/or y) --> or x, y
7072     SDValue N00 = N0.getOperand(0);
7073     SDValue N01 = N0.getOperand(1);
7074     if (N00 == N1)
7075       return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
7076     if (N01 == N1)
7077       return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
7078 
7079     if (N1.getOpcode() == ISD::AND || N1.getOpcode() == ISD::OR) {
7080       SDValue N10 = N1.getOperand(0);
7081       SDValue N11 = N1.getOperand(1);
7082       if ((N00 == N10 && N01 == N11) || (N00 == N11 && N01 == N10))
7083         return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N01);
7084     }
7085   }
7086 
7087   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7088     return R;
7089 
7090   auto peekThroughZext = [](SDValue V) {
7091     if (V->getOpcode() == ISD::ZERO_EXTEND)
7092       return V->getOperand(0);
7093     return V;
7094   };
7095 
7096   // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7097   if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7098       N0.getOperand(0) == N1.getOperand(0) &&
7099       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7100     return N0;
7101 
7102   // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7103   if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7104       N0.getOperand(1) == N1.getOperand(0) &&
7105       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7106     return N0;
7107 
7108   return SDValue();
7109 }
7110 
7111 SDValue DAGCombiner::visitOR(SDNode *N) {
7112   SDValue N0 = N->getOperand(0);
7113   SDValue N1 = N->getOperand(1);
7114   EVT VT = N1.getValueType();
7115 
7116   // x | x --> x
7117   if (N0 == N1)
7118     return N0;
7119 
7120   // fold (or c1, c2) -> c1|c2
7121   if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
7122     return C;
7123 
7124   // canonicalize constant to RHS
7125   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7126       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7127     return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
7128 
7129   // fold vector ops
7130   if (VT.isVector()) {
7131     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
7132       return FoldedVOp;
7133 
7134     // fold (or x, 0) -> x, vector edition
7135     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
7136       return N0;
7137 
7138     // fold (or x, -1) -> -1, vector edition
7139     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
7140       // do not return N1, because an undef node may exist in N1
7141       return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
7142 
7143     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7144     // Do this only if the resulting type / shuffle is legal.
7145     auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
7146     auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
7147     if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7148       bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
7149       bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
7150       bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
7151       bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
7152       // Ensure both shuffles have a zero input.
7153       if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7154         assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7155         assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7156         bool CanFold = true;
7157         int NumElts = VT.getVectorNumElements();
7158         SmallVector<int, 4> Mask(NumElts, -1);
7159 
7160         for (int i = 0; i != NumElts; ++i) {
7161           int M0 = SV0->getMaskElt(i);
7162           int M1 = SV1->getMaskElt(i);
7163 
7164           // Determine if either index is pointing to a zero vector.
7165           bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7166           bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7167 
7168           // If one element is zero and the other side is undef, keep undef.
7169           // This also handles the case that both are undef.
7170           if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7171             continue;
7172 
7173           // Make sure only one of the elements is zero.
7174           if (M0Zero == M1Zero) {
7175             CanFold = false;
7176             break;
7177           }
7178 
7179           assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7180 
7181           // We have a zero and non-zero element. If the non-zero came from
7182           // SV0 make the index a LHS index. If it came from SV1, make it
7183           // a RHS index. We need to mod by NumElts because we don't care
7184           // which operand it came from in the original shuffles.
7185           Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7186         }
7187 
7188         if (CanFold) {
7189           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
7190           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
7191 
7192           SDValue LegalShuffle =
7193               TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
7194                                           Mask, DAG);
7195           if (LegalShuffle)
7196             return LegalShuffle;
7197         }
7198       }
7199     }
7200   }
7201 
7202   // fold (or x, 0) -> x
7203   if (isNullConstant(N1))
7204     return N0;
7205 
7206   // fold (or x, -1) -> -1
7207   if (isAllOnesConstant(N1))
7208     return N1;
7209 
7210   if (SDValue NewSel = foldBinOpIntoSelect(N))
7211     return NewSel;
7212 
7213   // fold (or x, c) -> c iff (x & ~c) == 0
7214   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
7215   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
7216     return N1;
7217 
7218   if (SDValue Combined = visitORLike(N0, N1, N))
7219     return Combined;
7220 
7221   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7222     return Combined;
7223 
7224   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7225   if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7226     return BSwap;
7227   if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7228     return BSwap;
7229 
7230   // reassociate or
7231   if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
7232     return ROR;
7233 
7234   // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7235   // iff (c1 & c2) != 0 or c1/c2 are undef.
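  // E.g. (or (and X, 0x0F), 0x08) -> (and (or X, 0x08), 0x0F): both forms
  // keep the low nibble of X with bit 3 forced to one.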
7236   auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7237     return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
7238   };
7239   if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7240       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
7241     if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
7242                                                  {N1, N0.getOperand(1)})) {
7243       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
7244       AddToWorklist(IOR.getNode());
7245       return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
7246     }
7247   }
7248 
7249   if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7250     return Combined;
7251   if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
7252     return Combined;
7253 
7254   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
7255   if (N0.getOpcode() == N1.getOpcode())
7256     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7257       return V;
7258 
7259   // See if this is some rotate idiom.
7260   if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
7261     return Rot;
7262 
7263   if (SDValue Load = MatchLoadCombine(N))
7264     return Load;
7265 
7266   // Simplify the operands using demanded-bits information.
7267   if (SimplifyDemandedBits(SDValue(N, 0)))
7268     return SDValue(N, 0);
7269 
7270   // If OR can be rewritten into ADD, try combines based on ADD.
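  // This is sound because (or x, y) == (add x, y) when x and y share no set
  // bits, i.e. when no bit position can produce a carry.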
7271   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
7272       DAG.haveNoCommonBitsSet(N0, N1))
7273     if (SDValue Combined = visitADDLike(N))
7274       return Combined;
7275 
7276   // Postpone until legalization completed to avoid interference with bswap
7277   // folding
7278   if (LegalOperations || VT.isVector())
7279     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
7280       return R;
7281 
7282   return SDValue();
7283 }
7284 
7285 static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
7286                                  SDValue &Mask) {
7287   if (Op.getOpcode() == ISD::AND &&
7288       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
7289     Mask = Op.getOperand(1);
7290     return Op.getOperand(0);
7291   }
7292   return Op;
7293 }
7294 
7295 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
7296 static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
7297                             SDValue &Mask) {
7298   Op = stripConstantMask(DAG, Op, Mask);
7299   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
7300     Shift = Op;
7301     return true;
7302   }
7303   return false;
7304 }
7305 
7306 /// Helper function for visitOR to extract the needed side of a rotate idiom
7307 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
7308 /// InstCombine merged some outside op with one of the shifts from
7309 /// the rotate pattern.
7310 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
7311 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
7312 /// patterns:
7313 ///
7314 ///   (or (add v v) (shrl v bitwidth-1)):
7315 ///     expands (add v v) -> (shl v 1)
7316 ///
7317 ///   (or (mul v c0) (shrl (mul v c1) c2)):
7318 ///     expands (mul v c0) -> (shl (mul v c1) c3)
7319 ///
7320 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
7321 ///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
7322 ///
7323 ///   (or (shl v c0) (shrl (shl v c1) c2)):
7324 ///     expands (shl v c0) -> (shl (shl v c1) c3)
7325 ///
7326 ///   (or (shrl v c0) (shl (shrl v c1) c2)):
7327 ///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
7328 ///
7329 /// Such that in all cases, c3+c2==bitwidth(op v c1).
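///
/// For instance, on i32, (or (mul v 12), (srl (mul v 3), 30)) can treat
/// (mul v 12) as (shl (mul v 3), 2), since 12 == 3 << 2 and 2 + 30 == 32.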
7330 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
7331                                      SDValue ExtractFrom, SDValue &Mask,
7332                                      const SDLoc &DL) {
7333   assert(OppShift && ExtractFrom && "Empty SDValue");
7334   if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
7335     return SDValue();
7336 
7337   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
7338 
7339   // Value and Type of the shift.
7340   SDValue OppShiftLHS = OppShift.getOperand(0);
7341   EVT ShiftedVT = OppShiftLHS.getValueType();
7342 
7343   // Amount of the existing shift.
7344   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
7345 
7346   // (add v v) -> (shl v 1)
7347   // TODO: Should this be a general DAG canonicalization?
7348   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
7349       ExtractFrom.getOpcode() == ISD::ADD &&
7350       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
7351       ExtractFrom.getOperand(0) == OppShiftLHS &&
7352       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
7353     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
7354                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
7355 
7356   // Preconditions:
7357   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
7358   //
7359   // Find opcode of the needed shift to be extracted from (op0 v c0).
7360   unsigned Opcode = ISD::DELETED_NODE;
7361   bool IsMulOrDiv = false;
7362   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
7363   // opcode or its arithmetic (mul or udiv) variant.
7364   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
7365     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
7366     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
7367       return false;
7368     Opcode = NeededShift;
7369     return true;
7370   };
7371   // op0 must be either the needed shift opcode or the mul/udiv equivalent
7372   // that the needed shift can be extracted from.
7373   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
7374       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
7375     return SDValue();
7376 
7377   // op0 must be the same opcode on both sides, have the same LHS argument,
7378   // and produce the same value type.
7379   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
7380       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
7381       ShiftedVT != ExtractFrom.getValueType())
7382     return SDValue();
7383 
7384   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
7385   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
7386   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
7387   ConstantSDNode *ExtractFromCst =
7388       isConstOrConstSplat(ExtractFrom.getOperand(1));
7389   // TODO: We should be able to handle non-uniform constant vectors for these values
7390   // Check that we have constant values.
7391   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
7392       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
7393       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
7394     return SDValue();
7395 
7396   // Compute the shift amount we need to extract to complete the rotate.
7397   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
7398   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
7399     return SDValue();
7400   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
7401   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
7402   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
7403   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
7404   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
7405 
7406   // Now try extract the needed shift from the ExtractFrom op and see if the
7407   // result matches up with the existing shift's LHS op.
7408   if (IsMulOrDiv) {
7409     // Op to extract from is a mul or udiv by a constant.
7410     // Check:
7411     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
7412     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
7413     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
7414                                                  NeededShiftAmt.getZExtValue());
7415     APInt ResultAmt;
7416     APInt Rem;
7417     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
7418     if (Rem != 0 || ResultAmt != OppLHSAmt)
7419       return SDValue();
7420   } else {
7421     // Op to extract from is a shift by a constant.
7422     // Check:
7423     //      c2 - (bitwidth(op0 v c0) - c1) == c0
7424     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
7425                                           ExtractFromAmt.getBitWidth()))
7426       return SDValue();
7427   }
7428 
7429   // Return the expanded shift op that should allow a rotate to be formed.
7430   EVT ShiftVT = OppShift.getOperand(1).getValueType();
7431   EVT ResVT = ExtractFrom.getValueType();
7432   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
7433   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
7434 }
7435 
7436 // Return true if we can prove that, whenever Neg and Pos are both in the
7437 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
7438 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
7439 //
7440 //     (or (shift1 X, Neg), (shift2 X, Pos))
7441 //
7442 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
7443 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
7444 // to consider shift amounts with defined behavior.
7445 //
7446 // The IsRotate flag should be set when the LHS of both shifts is the same.
7447 // Otherwise if matching a general funnel shift, it should be clear.
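//
// E.g. with EltSize == 32, proving Neg == (sub 32, Pos) allows
// (or (shl X, Pos), (srl X, Neg)) to be treated as (rotl X, Pos).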
7448 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
7449                            SelectionDAG &DAG, bool IsRotate) {
7450   const auto &TLI = DAG.getTargetLoweringInfo();
7451   // If EltSize is a power of 2 then:
7452   //
7453   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
7454   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
7455   //
7456   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
7457   // for the stronger condition:
7458   //
7459   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
7460   //
7461   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
7462   // we can just replace Neg with Neg' for the rest of the function.
7463   //
7464   // In other cases we check for the even stronger condition:
7465   //
7466   //     Neg == EltSize - Pos                                    [B]
7467   //
7468   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
7469   // behavior if Pos == 0 (and consequently Neg == EltSize).
7470   //
7471   // We could actually use [A] whenever EltSize is a power of 2, but the
7472   // only extra cases that it would match are those uninteresting ones
7473   // where Neg and Pos are never in range at the same time.  E.g. for
7474   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
7475   // as well as (sub 32, Pos), but:
7476   //
7477   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
7478   //
7479   // always invokes undefined behavior for 32-bit X.
7480   //
7481   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
7482   // This allows us to peek through any operations that only affect Mask's
7483   // un-demanded bits.
7484   //
7485   // NOTE: We can only do this when matching operations which won't modify the
7486   // least Log2(EltSize) significant bits and not a general funnel shift.
7487   unsigned MaskLoBits = 0;
7488   if (IsRotate && isPowerOf2_64(EltSize)) {
7489     unsigned Bits = Log2_64(EltSize);
7490     unsigned NegBits = Neg.getScalarValueSizeInBits();
7491     if (NegBits >= Bits) {
7492       APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
7493       if (SDValue Inner =
7494               TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
7495         Neg = Inner;
7496         MaskLoBits = Bits;
7497       }
7498     }
7499   }
7500 
7501   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
7502   if (Neg.getOpcode() != ISD::SUB)
7503     return false;
7504   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
7505   if (!NegC)
7506     return false;
7507   SDValue NegOp1 = Neg.getOperand(1);
7508 
7509   // On the RHS of [A], if Pos is the result of operation on Pos' that won't
7510   // affect Mask's demanded bits, just replace Pos with Pos'. These operations
7511   // are redundant for the purpose of the equality.
7512   if (MaskLoBits) {
7513     unsigned PosBits = Pos.getScalarValueSizeInBits();
7514     if (PosBits >= MaskLoBits) {
7515       APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
7516       if (SDValue Inner =
7517               TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
7518         Pos = Inner;
7519       }
7520     }
7521   }
7522 
7523   // The condition we need is now:
7524   //
7525   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
7526   //
7527   // If NegOp1 == Pos then we need:
7528   //
7529   //              EltSize & Mask == NegC & Mask
7530   //
7531   // (because "x & Mask" is a truncation and distributes through subtraction).
7532   //
7533   // We also need to account for a potential truncation of NegOp1 if the amount
7534   // has already been legalized to a shift amount type.
7535   APInt Width;
7536   if ((Pos == NegOp1) ||
7537       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
7538     Width = NegC->getAPIntValue();
7539 
7540   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
7541   // Then the condition we want to prove becomes:
7542   //
7543   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
7544   //
7545   // which, again because "x & Mask" is a truncation, becomes:
7546   //
7547   //                NegC & Mask == (EltSize - PosC) & Mask
7548   //             EltSize & Mask == (NegC + PosC) & Mask
7549   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
7550     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
7551       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
7552     else
7553       return false;
7554   } else
7555     return false;
7556 
7557   // Now we just need to check that EltSize & Mask == Width & Mask.
7558   if (MaskLoBits)
7559     // EltSize & Mask is 0 since Mask is EltSize - 1.
7560     return Width.getLoBits(MaskLoBits) == 0;
7561   return Width == EltSize;
7562 }
7563 
7564 // A subroutine of MatchRotate used once we have found an OR of two opposite
7565 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
7566 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
7567 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
7568 // Neg with outer conversions stripped away.
7569 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
7570                                        SDValue Neg, SDValue InnerPos,
7571                                        SDValue InnerNeg, bool HasPos,
7572                                        unsigned PosOpcode, unsigned NegOpcode,
7573                                        const SDLoc &DL) {
7574   // fold (or (shl x, (*ext y)),
7575   //          (srl x, (*ext (sub 32, y)))) ->
7576   //   (rotl x, y) or (rotr x, (sub 32, y))
7577   //
7578   // fold (or (shl x, (*ext (sub 32, y))),
7579   //          (srl x, (*ext y))) ->
7580   //   (rotr x, y) or (rotl x, (sub 32, y))
7581   EVT VT = Shifted.getValueType();
7582   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
7583                      /*IsRotate*/ true)) {
7584     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
7585                        HasPos ? Pos : Neg);
7586   }
7587 
7588   return SDValue();
7589 }
7590 
7591 // A subroutine of MatchRotate used once we have found an OR of two opposite
7592 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
7593 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
7594 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
7595 // Neg with outer conversions stripped away.
7596 // TODO: Merge with MatchRotatePosNeg.
7597 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
7598                                        SDValue Neg, SDValue InnerPos,
7599                                        SDValue InnerNeg, bool HasPos,
7600                                        unsigned PosOpcode, unsigned NegOpcode,
7601                                        const SDLoc &DL) {
7602   EVT VT = N0.getValueType();
7603   unsigned EltBits = VT.getScalarSizeInBits();
7604 
7605   // fold (or (shl x0, (*ext y)),
7606   //          (srl x1, (*ext (sub 32, y)))) ->
7607   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
7608   //
7609   // fold (or (shl x0, (*ext (sub 32, y))),
7610   //          (srl x1, (*ext y))) ->
7611   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
7612   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
7613     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
7614                        HasPos ? Pos : Neg);
7615   }
7616 
7617   // Matching the shift+xor cases, we can't easily use the xor'd shift amount
7618   // so for now just use the PosOpcode case if its legal.
7619   // TODO: When can we use the NegOpcode case?
7620   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
7621     auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
7622       if (Op.getOpcode() != BinOpc)
7623         return false;
7624       ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
7625       return Cst && (Cst->getAPIntValue() == Imm);
7626     };
7627 
7628     // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
7629     //   -> (fshl x0, x1, y)
7630     if (IsBinOpImm(N1, ISD::SRL, 1) &&
7631         IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
7632         InnerPos == InnerNeg.getOperand(0) &&
7633         TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
7634       return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
7635     }
7636 
7637     // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
7638     //   -> (fshr x0, x1, y)
7639     if (IsBinOpImm(N0, ISD::SHL, 1) &&
7640         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
7641         InnerNeg == InnerPos.getOperand(0) &&
7642         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
7643       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
7644     }
7645 
7646     // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
7647     //   -> (fshr x0, x1, y)
7648     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
7649     if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
7650         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
7651         InnerNeg == InnerPos.getOperand(0) &&
7652         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
7653       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
7654     }
7655   }
7656 
7657   return SDValue();
7658 }
7659 
7660 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
7661 // idioms for rotate, and if the target supports rotation instructions, generate
7662 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
7663 // with different shifted sources.
7664 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
7665   EVT VT = LHS.getValueType();
7666 
7667   // The target must have at least one rotate/funnel flavor.
7668   // We still try to match rotate by constant pre-legalization.
7669   // TODO: Support pre-legalization funnel-shift by constant.
7670   bool HasROTL = hasOperation(ISD::ROTL, VT);
7671   bool HasROTR = hasOperation(ISD::ROTR, VT);
7672   bool HasFSHL = hasOperation(ISD::FSHL, VT);
7673   bool HasFSHR = hasOperation(ISD::FSHR, VT);
7674 
7675   // If the type is going to be promoted and the target has enabled custom
7676   // lowering for rotate, allow matching rotate by non-constants. Only allow
7677   // this for scalar types.
7678   if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
7679                                   TargetLowering::TypePromoteInteger) {
7680     HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
7681     HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
7682   }
7683 
7684   if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
7685     return SDValue();
7686 
7687   // Check for truncated rotate.
7688   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
7689       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
7690     assert(LHS.getValueType() == RHS.getValueType());
7691     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
7692       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
7693     }
7694   }
7695 
7696   // Match "(X shl/srl V1) & V2" where V2 may not be present.
7697   SDValue LHSShift;   // The shift.
7698   SDValue LHSMask;    // AND value if any.
7699   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
7700 
7701   SDValue RHSShift;   // The shift.
7702   SDValue RHSMask;    // AND value if any.
7703   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
7704 
7705   // If neither side matched a rotate half, bail
7706   if (!LHSShift && !RHSShift)
7707     return SDValue();
7708 
7709   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
7710   // side of the rotate, so try to handle that here. In all cases we need to
7711   // pass the matched shift from the opposite side to compute the opcode and
7712   // needed shift amount to extract.  We still want to do this if both sides
7713   // matched a rotate half because one half may be a potential overshift that
7714   // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
7715   // single one).
7716 
7717   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
7718   if (LHSShift)
7719     if (SDValue NewRHSShift =
7720             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
7721       RHSShift = NewRHSShift;
7722   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
7723   if (RHSShift)
7724     if (SDValue NewLHSShift =
7725             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
7726       LHSShift = NewLHSShift;
7727 
7728   // If a side is still missing, nothing else we can do.
7729   if (!RHSShift || !LHSShift)
7730     return SDValue();
7731 
7732   // At this point we've matched or extracted a shift op on each side.
7733 
7734   if (LHSShift.getOpcode() == RHSShift.getOpcode())
7735     return SDValue(); // Shifts must disagree.
7736 
7737   // Canonicalize shl to left side in a shl/srl pair.
7738   if (RHSShift.getOpcode() == ISD::SHL) {
7739     std::swap(LHS, RHS);
7740     std::swap(LHSShift, RHSShift);
7741     std::swap(LHSMask, RHSMask);
7742   }
7743 
7744   // Something has gone wrong - we've lost the shl/srl pair - bail.
7745   if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
7746     return SDValue();
7747 
7748   unsigned EltSizeInBits = VT.getScalarSizeInBits();
7749   SDValue LHSShiftArg = LHSShift.getOperand(0);
7750   SDValue LHSShiftAmt = LHSShift.getOperand(1);
7751   SDValue RHSShiftArg = RHSShift.getOperand(0);
7752   SDValue RHSShiftAmt = RHSShift.getOperand(1);
7753 
7754   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
7755                                         ConstantSDNode *RHS) {
7756     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
7757   };
7758 
7759   auto ApplyMasks = [&](SDValue Res) {
7760     // If there is an AND of either shifted operand, apply it to the result.
7761     if (LHSMask.getNode() || RHSMask.getNode()) {
7762       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
7763       SDValue Mask = AllOnes;
7764 
7765       if (LHSMask.getNode()) {
7766         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
7767         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7768                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
7769       }
7770       if (RHSMask.getNode()) {
7771         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
7772         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7773                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
7774       }
7775 
7776       Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
7777     }
7778 
7779     return Res;
7780   };
7781 
7782   // TODO: Support pre-legalization funnel-shift by constant.
7783   bool IsRotate = LHSShiftArg == RHSShiftArg;
7784   if (!IsRotate && !(HasFSHL || HasFSHR)) {
7785     if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
7786         ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
7787       // Look for a disguised rotate by constant.
7788       // The common shifted operand X may be hidden inside another 'or'.
7789       SDValue X, Y;
7790       auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
7791         if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
7792           return false;
7793         if (CommonOp == Or.getOperand(0)) {
7794           X = CommonOp;
7795           Y = Or.getOperand(1);
7796           return true;
7797         }
7798         if (CommonOp == Or.getOperand(1)) {
7799           X = CommonOp;
7800           Y = Or.getOperand(0);
7801           return true;
7802         }
7803         return false;
7804       };
7805 
7806       SDValue Res;
7807       if (matchOr(LHSShiftArg, RHSShiftArg)) {
7808         // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
7809         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
7810         SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
7811         Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
7812       } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
7813         // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
7814         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
7815         SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
7816         Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
7817       } else {
7818         return SDValue();
7819       }
7820 
7821       return ApplyMasks(Res);
7822     }
7823 
7824     return SDValue(); // Requires funnel shift support.
7825   }
7826 
7827   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
7828   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
7829   // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
7830   // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
7831   // iff C1+C2 == EltSizeInBits
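       // For example, with i32 elements (EltSizeInBits == 32):
       //   (or (shl x, 8), (srl x, 24)) --> (rotl x, 8) == (rotr x, 24)
       // since 8 + 24 == 32.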
7832   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
7833     SDValue Res;
7834     if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
7835       bool UseROTL = !LegalOperations || HasROTL;
7836       Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
7837                         UseROTL ? LHSShiftAmt : RHSShiftAmt);
7838     } else {
7839       bool UseFSHL = !LegalOperations || HasFSHL;
7840       Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
7841                         RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
7842     }
7843 
7844     return ApplyMasks(Res);
7845   }
7846 
7847   // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
7848   // shift.
7849   if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
7850     return SDValue();
7851 
7852   // If there is a mask here, and we have a variable shift, we can't be sure
7853   // that we're masking out the right stuff.
7854   if (LHSMask.getNode() || RHSMask.getNode())
7855     return SDValue();
7856 
7857   // If the shift amount is sign/zext/any-extended just peel it off.
7858   SDValue LExtOp0 = LHSShiftAmt;
7859   SDValue RExtOp0 = RHSShiftAmt;
7860   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7861        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7862        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7863        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
7864       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7865        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7866        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7867        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
7868     LExtOp0 = LHSShiftAmt.getOperand(0);
7869     RExtOp0 = RHSShiftAmt.getOperand(0);
7870   }
7871 
7872   if (IsRotate && (HasROTL || HasROTR)) {
7873     SDValue TryL =
7874         MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
7875                           RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
7876     if (TryL)
7877       return TryL;
7878 
7879     SDValue TryR =
7880         MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
7881                           LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
7882     if (TryR)
7883       return TryR;
7884   }
7885 
7886   SDValue TryL =
7887       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
7888                         LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
7889   if (TryL)
7890     return TryL;
7891 
7892   SDValue TryR =
7893       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
7894                         RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
7895   if (TryR)
7896     return TryR;
7897 
7898   return SDValue();
7899 }
7900 
7901 namespace {
7902 
7903 /// Represents the known origin of an individual byte in a load combine pattern.
7904 /// The value of the byte is either constant zero or comes from memory.
7905 struct ByteProvider {
7906   // For constant zero providers Load is set to nullptr. For memory providers
7907   // Load represents the node which loads the byte from memory.
7908   // ByteOffset is the offset of the byte in the value produced by the load.
7909   LoadSDNode *Load = nullptr;
7910   unsigned ByteOffset = 0;
7911   unsigned VectorOffset = 0;
7912 
7913   ByteProvider() = default;
7914 
7915   static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset,
7916                                 unsigned VectorOffset) {
7917     return ByteProvider(Load, ByteOffset, VectorOffset);
7918   }
7919 
7920   static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); }
7921 
7922   bool isConstantZero() const { return !Load; }
7923   bool isMemory() const { return Load; }
7924 
7925   bool operator==(const ByteProvider &Other) const {
7926     return Other.Load == Load && Other.ByteOffset == ByteOffset &&
7927            Other.VectorOffset == VectorOffset;
7928   }
7929 
7930 private:
7931   ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset)
7932       : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {}
7933 };
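
     // For example, byte 1 of (zext (load i16, %p) to i32) has the provider
     // {Load = the i16 load, ByteOffset = 1}, while bytes 2 and 3 of that
     // value are constant-zero providers (Load == nullptr).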
7934 
7935 } // end anonymous namespace
7936 
7937 /// Recursively traverses the expression calculating the origin of the requested
7938 /// byte of the given value. Returns std::nullopt if the provider can't be
7939 /// calculated.
7940 ///
7941 /// For all the values except the root of the expression, we verify that the
7942 /// value has exactly one use and if not then return std::nullopt. This way if
7943 /// the origin of the byte is returned it's guaranteed that the values which
7944 /// contribute to the byte are not used outside of this expression.
7945 ///
7946 /// However, there is a special case when dealing with vector loads -- we allow
7947 /// more than one use if the load is a vector type.  Since the values that
7948 /// contribute to the byte ultimately come from the ExtractVectorElements of the
7949 /// Load, we don't care if the Load has uses other than ExtractVectorElements,
7950 /// because those operations are independent from the pattern to be combined.
7951 /// For vector loads, we simply care that the ByteProviders are adjacent
7952 /// positions of the same vector, and their index matches the byte that is being
7953 /// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
7954 /// is the index used in an ExtractVectorElement, and \p StartingIndex is the
7955 /// byte position we are trying to provide for the LoadCombine. If these do
7956 /// not match, then we cannot combine the vector loads. \p Index tracks the
7957 /// byte position we are trying to provide for and is matched against the
7958 /// shl and load size. The \p Index algorithm ensures the requested byte is
7959 /// provided for by the pattern, and the pattern does not over-provide bytes.
7960 ///
7961 ///
7962 /// The supported LoadCombine pattern for vector loads is as follows
7963 ///                              or
7964 ///                          /        \
7965 ///                         or        shl
7966 ///                       /     \      |
7967 ///                     or      shl   zext
7968 ///                   /    \     |     |
7969 ///                 shl   zext  zext  EVE*
7970 ///                  |     |     |     |
7971 ///                 zext  EVE*  EVE*  LOAD
7972 ///                  |     |     |
7973 ///                 EVE*  LOAD  LOAD
7974 ///                  |
7975 ///                 LOAD
7976 ///
7977 /// *ExtractVectorElement
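     ///
     /// For example, given the scalar pattern
     ///   (or (zext (load i8, %p)), (shl (zext (load i8, %p1)), 8))
     /// requesting byte 1 recurses through the shl (subtracting its one-byte
     /// shift), the zext and the load, and yields the provider
     /// {Load = (load i8, %p1), ByteOffset = 0}.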
7978 static const std::optional<ByteProvider>
7979 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
7980                       std::optional<uint64_t> VectorIndex,
7981                       unsigned StartingIndex = 0) {
7982 
7983   // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
7984   if (Depth == 10)
7985     return std::nullopt;
7986 
7987   // Only allow multiple uses if the instruction is a vector load (in which
7988   // case we will use the load for every ExtractVectorElement)
7989   if (Depth && !Op.hasOneUse() &&
7990       (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
7991     return std::nullopt;
7992 
7993   // Fail to combine if we have encountered anything but a LOAD after handling
7994   // an ExtractVectorElement.
7995   if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
7996     return std::nullopt;
7997 
7998   unsigned BitWidth = Op.getValueSizeInBits();
7999   if (BitWidth % 8 != 0)
8000     return std::nullopt;
8001   unsigned ByteWidth = BitWidth / 8;
8002   assert(Index < ByteWidth && "invalid index requested");
8003   (void) ByteWidth;
8004 
8005   switch (Op.getOpcode()) {
8006   case ISD::OR: {
8007     auto LHS =
8008         calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
8009     if (!LHS)
8010       return std::nullopt;
8011     auto RHS =
8012         calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
8013     if (!RHS)
8014       return std::nullopt;
8015 
8016     if (LHS->isConstantZero())
8017       return RHS;
8018     if (RHS->isConstantZero())
8019       return LHS;
8020     return std::nullopt;
8021   }
8022   case ISD::SHL: {
8023     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8024     if (!ShiftOp)
8025       return std::nullopt;
8026 
8027     uint64_t BitShift = ShiftOp->getZExtValue();
8028 
8029     if (BitShift % 8 != 0)
8030       return std::nullopt;
8031     uint64_t ByteShift = BitShift / 8;
8032 
8033     // If we are shifting by an amount greater than the index we are trying to
8034     // provide, then do not provide anything. Otherwise, subtract the byte
8035     // shift from the index.
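         // For example, (shl x, 16) with Index == 2 requests byte 0 of x
         // (ByteShift == 2), while Index < 2 yields a constant-zero provider.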
8036     return Index < ByteShift
8037                ? ByteProvider::getConstantZero()
8038                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
8039                                        Depth + 1, VectorIndex, Index);
8040   }
8041   case ISD::ANY_EXTEND:
8042   case ISD::SIGN_EXTEND:
8043   case ISD::ZERO_EXTEND: {
8044     SDValue NarrowOp = Op->getOperand(0);
8045     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8046     if (NarrowBitWidth % 8 != 0)
8047       return std::nullopt;
8048     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8049 
8050     if (Index >= NarrowByteWidth)
8051       return Op.getOpcode() == ISD::ZERO_EXTEND
8052                  ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
8053                  : std::nullopt;
8054     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
8055                                  StartingIndex);
8056   }
8057   case ISD::BSWAP:
8058     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
8059                                  Depth + 1, VectorIndex, StartingIndex);
8060   case ISD::EXTRACT_VECTOR_ELT: {
8061     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8062     if (!OffsetOp)
8063       return std::nullopt;
8064 
8065     VectorIndex = OffsetOp->getZExtValue();
8066 
8067     SDValue NarrowOp = Op->getOperand(0);
8068     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8069     if (NarrowBitWidth % 8 != 0)
8070       return std::nullopt;
8071     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8072 
8073     // Check to see if the position of the element in the vector corresponds
8074     // with the byte we are trying to provide for. In the case of a vector of
8075     // i8, this simply means the VectorIndex == StartingIndex. For non-i8 cases,
8076     // the element will provide a range of bytes. For example, if we have a
8077     // vector of i16s, each element provides two bytes (V[1] provides bytes 2
8078     // and 3).
8079     if (*VectorIndex * NarrowByteWidth > StartingIndex)
8080       return std::nullopt;
8081     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8082       return std::nullopt;
8083 
8084     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
8085                                  VectorIndex, StartingIndex);
8086   }
8087   case ISD::LOAD: {
8088     auto L = cast<LoadSDNode>(Op.getNode());
8089     if (!L->isSimple() || L->isIndexed())
8090       return std::nullopt;
8091 
8092     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8093     if (NarrowBitWidth % 8 != 0)
8094       return std::nullopt;
8095     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8096 
8097     // If the width of the load does not reach the byte we are trying to
8098     // provide for, and it is not a ZEXTLOAD, then the load does not provide
8099     // for the byte in question.
8100     if (Index >= NarrowByteWidth)
8101       return L->getExtensionType() == ISD::ZEXTLOAD
8102                  ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
8103                  : std::nullopt;
8104 
8105     unsigned BPVectorIndex = VectorIndex.value_or(0U);
8106     return ByteProvider::getMemory(L, Index, BPVectorIndex);
8107   }
8108   }
8109 
8110   return std::nullopt;
8111 }
8112 
8113 static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8114   return i;
8115 }
8116 
8117 static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8118   return BW - i - 1;
8119 }
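
     // For example, with BW == 4, littleEndianByteAt maps byte i of the value
     // to memory offset i, while bigEndianByteAt maps bytes 0..3 to offsets
     // 3..0.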
8120 
8121 // Check if the byte offsets we are looking at match either a big or little
8122 // endian value load. Return true for big endian, false for little endian,
8123 // and std::nullopt if the match failed.
8124 static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8125                                        int64_t FirstOffset) {
8126   // Endianness can be decided only for widths of at least 2 bytes.
8127   unsigned Width = ByteOffsets.size();
8128   if (Width < 2)
8129     return std::nullopt;
8130 
8131   bool BigEndian = true, LittleEndian = true;
8132   for (unsigned i = 0; i < Width; i++) {
8133     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8134     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
8135     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
8136     if (!BigEndian && !LittleEndian)
8137       return std::nullopt;
8138   }
8139 
8140   assert((BigEndian != LittleEndian) && "It should be either big endian or "
8141                                         "little endian");
8142   return BigEndian;
8143 }
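
     // For example, byte offsets {0, 1, 2, 3} (relative to FirstOffset) match a
     // little-endian load and return false; {3, 2, 1, 0} match a big-endian
     // load and return true; {0, 2, 1, 3} match neither and return std::nullopt.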
8144 
8145 static SDValue stripTruncAndExt(SDValue Value) {
8146   switch (Value.getOpcode()) {
8147   case ISD::TRUNCATE:
8148   case ISD::ZERO_EXTEND:
8149   case ISD::SIGN_EXTEND:
8150   case ISD::ANY_EXTEND:
8151     return stripTruncAndExt(Value.getOperand(0));
8152   }
8153   return Value;
8154 }
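
     // For example, stripTruncAndExt of (trunc (zext (sext X))) returns X.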
8155 
8156 /// Match a pattern where a wide type scalar value is stored by several narrow
8157 /// stores. Fold it into a single store or a BSWAP and a store if the target
8158 /// supports it.
8159 ///
8160 /// Assuming little endian target:
8161 ///  i8 *p = ...
8162 ///  i32 val = ...
8163 ///  p[0] = (val >> 0) & 0xFF;
8164 ///  p[1] = (val >> 8) & 0xFF;
8165 ///  p[2] = (val >> 16) & 0xFF;
8166 ///  p[3] = (val >> 24) & 0xFF;
8167 /// =>
8168 ///  *((i32)p) = val;
8169 ///
8170 ///  i8 *p = ...
8171 ///  i32 val = ...
8172 ///  p[0] = (val >> 24) & 0xFF;
8173 ///  p[1] = (val >> 16) & 0xFF;
8174 ///  p[2] = (val >> 8) & 0xFF;
8175 ///  p[3] = (val >> 0) & 0xFF;
8176 /// =>
8177 ///  *((i32)p) = BSWAP(val);
8178 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8179   // The matching looks for "store (trunc x)" patterns that appear early but are
8180   // likely to be replaced by truncating store nodes during combining.
8181   // TODO: If there is evidence that running this later would help, this
8182   //       limitation could be removed. Legality checks may need to be added
8183   //       for the created store and optional bswap/rotate.
8184   if (LegalOperations || OptLevel == CodeGenOpt::None)
8185     return SDValue();
8186 
8187   // We only handle merging simple stores of 1-4 bytes.
8188   // TODO: Allow unordered atomics when wider type is legal (see D66309)
8189   EVT MemVT = N->getMemoryVT();
8190   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8191       !N->isSimple() || N->isIndexed())
8192     return SDValue();
8193 
8194   // Collect all of the stores in the chain.
8195   SDValue Chain = N->getChain();
8196   SmallVector<StoreSDNode *, 8> Stores = {N};
8197   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
8198     // All stores must be the same size to ensure that we are writing all of the
8199     // bytes in the wide value.
8200     // This store should have exactly one use as a chain operand for another
8201     // store in the merging set. If there are other chain uses, then the
8202     // transform may not be safe because order of loads/stores outside of this
8203     // set may not be preserved.
8204     // TODO: We could allow multiple sizes by tracking each stored byte.
8205     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8206         Store->isIndexed() || !Store->hasOneUse())
8207       return SDValue();
8208     Stores.push_back(Store);
8209     Chain = Store->getChain();
8210   }
8211   // There is no reason to continue if we do not have at least a pair of stores.
8212   if (Stores.size() < 2)
8213     return SDValue();
8214 
8215   // Handle simple types only.
8216   LLVMContext &Context = *DAG.getContext();
8217   unsigned NumStores = Stores.size();
8218   unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
8219   unsigned WideNumBits = NumStores * NarrowNumBits;
8220   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
8221   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8222     return SDValue();
8223 
8224   // Check if all bytes of the source value that we are looking at are stored
8225   // to the same base address. Collect offsets from Base address into OffsetMap.
8226   SDValue SourceValue;
8227   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8228   int64_t FirstOffset = INT64_MAX;
8229   StoreSDNode *FirstStore = nullptr;
8230   std::optional<BaseIndexOffset> Base;
8231   for (auto *Store : Stores) {
8232     // Each store supplies a different part of the combined source value. A
8233     // truncate is required to get the partial value.
8234     SDValue Trunc = Store->getValue();
8235     if (Trunc.getOpcode() != ISD::TRUNCATE)
8236       return SDValue();
8237     // Other than the first/last part, a shift operation is required to get the
8238     // offset.
8239     int64_t Offset = 0;
8240     SDValue WideVal = Trunc.getOperand(0);
8241     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8242         isa<ConstantSDNode>(WideVal.getOperand(1))) {
8243       // The shift amount must be a constant multiple of the narrow type.
8244       // It is translated to the offset address in the wide source value "y".
8245       //
8246       // x = srl y, ShiftAmtC
8247       // i8 z = trunc x
8248       // store z, ...
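           // For example, with NarrowNumBits == 8 and ShiftAmtC == 16, this
           // store provides byte 2 (Offset == 2) of the wide source value.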
8249       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
8250       if (ShiftAmtC % NarrowNumBits != 0)
8251         return SDValue();
8252 
8253       Offset = ShiftAmtC / NarrowNumBits;
8254       WideVal = WideVal.getOperand(0);
8255     }
8256 
8257     // Stores must share the same source value with different offsets.
8258     // Truncate and extends should be stripped to get the single source value.
8259     if (!SourceValue)
8260       SourceValue = WideVal;
8261     else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
8262       return SDValue();
8263     else if (SourceValue.getValueType() != WideVT) {
8264       if (WideVal.getValueType() == WideVT ||
8265           WideVal.getScalarValueSizeInBits() >
8266               SourceValue.getScalarValueSizeInBits())
8267         SourceValue = WideVal;
8268       // Give up if the source value type is smaller than the store size.
8269       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8270         return SDValue();
8271     }
8272 
8273     // Stores must share the same base address.
8274     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
8275     int64_t ByteOffsetFromBase = 0;
8276     if (!Base)
8277       Base = Ptr;
8278     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8279       return SDValue();
8280 
8281     // Remember the first store.
8282     if (ByteOffsetFromBase < FirstOffset) {
8283       FirstStore = Store;
8284       FirstOffset = ByteOffsetFromBase;
8285     }
8286     // Map the offset in the store to the offset in the combined value, and
8287     // return early if it has been set before.
8288     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8289       return SDValue();
8290     OffsetMap[Offset] = ByteOffsetFromBase;
8291   }
8292 
8293   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8294   assert(FirstStore && "First store must be set");
8295 
8296   // Check that a store of the wide type is both allowed and fast on the target
8297   const DataLayout &Layout = DAG.getDataLayout();
8298   unsigned Fast = 0;
8299   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
8300                                         *FirstStore->getMemOperand(), &Fast);
8301   if (!Allowed || !Fast)
8302     return SDValue();
8303 
8304   // Check if the pieces of the value are going to the expected places in memory
8305   // to merge the stores.
8306   auto checkOffsets = [&](bool MatchLittleEndian) {
8307     if (MatchLittleEndian) {
8308       for (unsigned i = 0; i != NumStores; ++i)
8309         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
8310           return false;
8311     } else { // MatchBigEndian by reversing loop counter.
8312       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
8313         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
8314           return false;
8315     }
8316     return true;
8317   };
8318 
8319   // Check if the offsets line up for the native data layout of this target.
8320   bool NeedBswap = false;
8321   bool NeedRotate = false;
8322   if (!checkOffsets(Layout.isLittleEndian())) {
8323     // Special-case: check if byte offsets line up for the opposite endian.
8324     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
8325       NeedBswap = true;
8326     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
8327       NeedRotate = true;
8328     else
8329       return SDValue();
8330   }
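       // For example, on a little-endian target, two i16 stores holding the
       // high and low halves of an i32 at swapped offsets pass the big-endian
       // check above and are merged into one i32 store of (rotr val, 16).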
8331 
8332   SDLoc DL(N);
8333   if (WideVT != SourceValue.getValueType()) {
8334     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
8335            "Unexpected store value to merge");
8336     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
8337   }
8338 
8339   // Before legalize we can introduce illegal bswaps/rotates, which will later
8340   // be converted to an explicit bswap sequence. This way we end up with a single
8341   // store and byte shuffling instead of several stores and byte shuffling.
8342   if (NeedBswap) {
8343     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
8344   } else if (NeedRotate) {
8345     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
8346     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
8347     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
8348   }
8349 
8350   SDValue NewStore =
8351       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
8352                    FirstStore->getPointerInfo(), FirstStore->getAlign());
8353 
8354   // Rely on other DAG combine rules to remove the other individual stores.
8355   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
8356   return NewStore;
8357 }
8358 
8359 /// Match a pattern where a wide type scalar value is loaded by several narrow
8360 /// loads and combined by shifts and ors. Fold it into a single load or a load
8361 /// and a BSWAP if the target supports it.
8362 ///
8363 /// Assuming little endian target:
8364 ///  i8 *a = ...
8365 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
8366 /// =>
8367 ///  i32 val = *((i32)a)
8368 ///
8369 ///  i8 *a = ...
8370 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
8371 /// =>
8372 ///  i32 val = BSWAP(*((i32)a))
8373 ///
8374 /// TODO: This rule matches complex patterns with OR node roots and doesn't
8375 /// interact well with the worklist mechanism. When a part of the pattern is
8376 /// updated (e.g. one of the loads) its direct users are put into the worklist,
8377 /// but the root node of the pattern which triggers the load combine is not
8378 /// necessarily a direct user of the changed node. For example, once the address
8379 /// of t28 load is reassociated load combine won't be triggered:
8380 ///             t25: i32 = add t4, Constant:i32<2>
8381 ///           t26: i64 = sign_extend t25
8382 ///        t27: i64 = add t2, t26
8383 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
8384 ///     t29: i32 = zero_extend t28
8385 ///   t32: i32 = shl t29, Constant:i8<8>
8386 /// t33: i32 = or t23, t32
8387 /// As a possible fix visitLoad can check if the load can be a part of a load
8388 /// combine pattern and add corresponding OR roots to the worklist.
8389 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
8390   assert(N->getOpcode() == ISD::OR &&
8391          "Can only match load combining against OR nodes");
8392 
8393   // Handles simple types only
8394   EVT VT = N->getValueType(0);
8395   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
8396     return SDValue();
8397   unsigned ByteWidth = VT.getSizeInBits() / 8;
8398 
8399   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
8400   auto MemoryByteOffset = [&] (ByteProvider P) {
8401     assert(P.isMemory() && "Must be a memory byte provider");
8402     unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits();
8403 
8404     assert(LoadBitWidth % 8 == 0 &&
8405            "can only analyze providers for individual bytes not bit");
8406     unsigned LoadByteWidth = LoadBitWidth / 8;
8407     return IsBigEndianTarget
8408             ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
8409             : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
8410   };
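       // For example, a provider for byte 2 of a 4-byte load maps to memory
       // offset 2 on a little-endian target and to 4 - 2 - 1 == 1 on a
       // big-endian one.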
8411 
8412   std::optional<BaseIndexOffset> Base;
8413   SDValue Chain;
8414 
8415   SmallPtrSet<LoadSDNode *, 8> Loads;
8416   std::optional<ByteProvider> FirstByteProvider;
8417   int64_t FirstOffset = INT64_MAX;
8418 
8419   // Check if all the bytes of the OR we are looking at are loaded from the same
8420   // base address. Collect byte offsets from the Base address in ByteOffsets.
8421   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
8422   unsigned ZeroExtendedBytes = 0;
8423   for (int i = ByteWidth - 1; i >= 0; --i) {
8424     auto P =
8425         calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
8426                               /*StartingIndex*/ i);
8427     if (!P)
8428       return SDValue();
8429 
8430     if (P->isConstantZero()) {
8431       // It's OK for the N most significant bytes to be 0; we can just
8432       // zero-extend the load.
8433       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
8434         return SDValue();
8435       continue;
8436     }
8437     assert(P->isMemory() && "provenance should either be memory or zero");
8438 
8439     LoadSDNode *L = P->Load;
8440 
8441     // All loads must share the same chain
8442     SDValue LChain = L->getChain();
8443     if (!Chain)
8444       Chain = LChain;
8445     else if (Chain != LChain)
8446       return SDValue();
8447 
8448     // Loads must share the same base address
8449     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
8450     int64_t ByteOffsetFromBase = 0;
8451 
8452     // For vector loads, the expected load combine pattern will have an
8453     // ExtractElement for each index in the vector. While each of these
8454     // ExtractElements will be accessing the same base address as determined
8455     // by the load instruction, the actual bytes they interact with will differ
8456     // due to different ExtractElement indices. To accurately determine the
8457     // byte position of an ExtractElement, we offset the base load ptr with
8458     // the index multiplied by the byte size of each element in the vector.
8459     if (L->getMemoryVT().isVector()) {
8460       unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
8461       if (LoadWidthInBit % 8 != 0)
8462         return SDValue();
8463       unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8;
8464       Ptr.addToOffset(ByteOffsetFromVector);
8465     }
8466 
8467     if (!Base)
8468       Base = Ptr;
8469 
8470     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8471       return SDValue();
8472 
8473     // Calculate the offset of the current byte from the base address
8474     ByteOffsetFromBase += MemoryByteOffset(*P);
8475     ByteOffsets[i] = ByteOffsetFromBase;
8476 
8477     // Remember the first byte load
8478     if (ByteOffsetFromBase < FirstOffset) {
8479       FirstByteProvider = P;
8480       FirstOffset = ByteOffsetFromBase;
8481     }
8482 
8483     Loads.insert(L);
8484   }
8485 
8486   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
8487          "memory, so there must be at least one load which produces the value");
8488   assert(Base && "Base address of the accessed memory location must be set");
8489   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8490 
8491   bool NeedsZext = ZeroExtendedBytes > 0;
8492 
8493   EVT MemVT =
8494       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
8495 
8496   if (!MemVT.isSimple())
8497     return SDValue();
8498 
8499   // Before legalize we can introduce too-wide illegal loads, which will later
8500   // be split into legal-sized loads. This enables us to combine i64-load-by-i8
8501   // patterns into a couple of i32 loads on 32-bit targets.
8502   if (LegalOperations &&
8503       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
8504                             MemVT))
8505     return SDValue();
8506 
8507   // Check if the bytes of the OR we are looking at match either a big or
8508   // little endian value load.
8509   std::optional<bool> IsBigEndian = isBigEndian(
8510       ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
8511   if (!IsBigEndian)
8512     return SDValue();
8513 
8514   assert(FirstByteProvider && "must be set");
8515 
8516   // Ensure that the first byte is loaded from zero offset of the first load.
8517   // So the combined value can be loaded from the first load address.
8518   if (MemoryByteOffset(*FirstByteProvider) != 0)
8519     return SDValue();
8520   LoadSDNode *FirstLoad = FirstByteProvider->Load;
8521 
8522   // The node we are looking at matches with the pattern, check if we can
8523   // replace it with a single (possibly zero-extended) load and bswap + shift if
8524   // needed.
8525 
8526   // If the load needs byte swap check if the target supports it
8527   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
8528 
8529   // Before legalize we can introduce illegal bswaps which will be later
8530   // converted to an explicit bswap sequence. This way we end up with a single
8531   // load and byte shuffling instead of several loads and byte shuffling.
8532   // We do not introduce illegal bswaps when zero-extending as this tends to
8533   // introduce too many arithmetic instructions.
8534   if (NeedsBswap && (LegalOperations || NeedsZext) &&
8535       !TLI.isOperationLegal(ISD::BSWAP, VT))
8536     return SDValue();
8537 
8538   // If we need to bswap and zero extend, we have to insert a shift. Check that
8539   // it is legal.
8540   if (NeedsBswap && NeedsZext && LegalOperations &&
8541       !TLI.isOperationLegal(ISD::SHL, VT))
8542     return SDValue();
8543 
8544   // Check that a load of the wide type is both allowed and fast on the target
8545   unsigned Fast = 0;
8546   bool Allowed =
8547       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
8548                              *FirstLoad->getMemOperand(), &Fast);
8549   if (!Allowed || !Fast)
8550     return SDValue();
8551 
8552   SDValue NewLoad =
8553       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
8554                      Chain, FirstLoad->getBasePtr(),
8555                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
8556 
8557   // Transfer chain users from old loads to the new load.
8558   for (LoadSDNode *L : Loads)
8559     DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
8560 
8561   if (!NeedsBswap)
8562     return NewLoad;
8563 
8564   SDValue ShiftedLoad =
8565       NeedsZext
8566           ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
8567                         DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
8568                                                    SDLoc(N), LegalOperations))
8569           : NewLoad;
8570   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
8571 }
8572 
8573 // If the target has andn, bsl, or a similar bit-select instruction,
8574 // we want to unfold masked merge, with canonical pattern of:
8575 //   |        A  |  |B|
8576 //   ((x ^ y) & m) ^ y
8577 //    |  D  |
8578 // Into:
8579 //   (x & m) | (y & ~m)
8580 // If y is a constant, m is not a 'not', and the 'andn' does not work with
8581 // immediates, we unfold into a different pattern:
8582 //   ~(~x & m) & (m | y)
8583 // If x is a constant, m is a 'not', and the 'andn' does not work with
8584 // immediates, we unfold into a different pattern:
8585 //   (x | ~m) & ~(~m & ~y)
8586 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
8587 //       the very least that breaks andnpd / andnps patterns, and because those
8588 //       patterns are simplified in IR and shouldn't be created in the DAG
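     // For example, for i8 values x = 0xAB, y = 0xCD, m = 0xF0:
     //   ((x ^ y) & m) ^ y == (0x66 & 0xF0) ^ 0xCD == 0x60 ^ 0xCD == 0xAD
     //   (x & m) | (y & ~m) ==  0xA0 | 0x0D                       == 0xAD
     // i.e. the mask m selects the bits of x where it is set and the bits of y
     // where it is clear.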
8589 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
8590   assert(N->getOpcode() == ISD::XOR);
8591 
8592   // Don't touch 'not' (i.e. where y = -1).
8593   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
8594     return SDValue();
8595 
8596   EVT VT = N->getValueType(0);
8597 
8598   // There are 3 commutable operators in the pattern,
8599   // so we have to deal with 8 possible variants of the basic pattern.
8600   SDValue X, Y, M;
8601   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
8602     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
8603       return false;
8604     SDValue Xor = And.getOperand(XorIdx);
8605     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
8606       return false;
8607     SDValue Xor0 = Xor.getOperand(0);
8608     SDValue Xor1 = Xor.getOperand(1);
8609     // Don't touch 'not' (i.e. where y = -1).
8610     if (isAllOnesOrAllOnesSplat(Xor1))
8611       return false;
8612     if (Other == Xor0)
8613       std::swap(Xor0, Xor1);
8614     if (Other != Xor1)
8615       return false;
8616     X = Xor0;
8617     Y = Xor1;
8618     M = And.getOperand(XorIdx ? 0 : 1);
8619     return true;
8620   };
8621 
8622   SDValue N0 = N->getOperand(0);
8623   SDValue N1 = N->getOperand(1);
8624   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
8625       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
8626     return SDValue();
8627 
8628   // Don't do anything if the mask is constant. This should not be reachable.
8629   // InstCombine should have already unfolded this pattern, and DAGCombiner
8630   // probably shouldn't produce it either.
8631   if (isa<ConstantSDNode>(M.getNode()))
8632     return SDValue();
8633 
8634   // We can transform if the target has AndNot
8635   if (!TLI.hasAndNot(M))
8636     return SDValue();
8637 
8638   SDLoc DL(N);
8639 
8640   // If Y is a constant, check that 'andn' works with immediates. Unless M is
8641   // a bitwise not that would already allow ANDN to be used.
8642   if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
8643     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
8644     // If not, we need to do a bit more work to make sure andn is still used.
8645     SDValue NotX = DAG.getNOT(DL, X, VT);
8646     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
8647     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
8648     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
8649     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
8650   }
8651 
8652   // If X is a constant and M is a bitwise not, check that 'andn' works with
8653   // immediates.
8654   if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
8655     assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
8656     // If not, we need to do a bit more work to make sure andn is still used.
8657     SDValue NotM = M.getOperand(0);
8658     SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
8659     SDValue NotY = DAG.getNOT(DL, Y, VT);
8660     SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
8661     SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
8662     return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
8663   }
8664 
8665   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
8666   SDValue NotM = DAG.getNOT(DL, M, VT);
8667   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
8668 
8669   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
8670 }
8671 
8672 SDValue DAGCombiner::visitXOR(SDNode *N) {
8673   SDValue N0 = N->getOperand(0);
8674   SDValue N1 = N->getOperand(1);
8675   EVT VT = N0.getValueType();
8676   SDLoc DL(N);
8677 
8678   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
8679   if (N0.isUndef() && N1.isUndef())
8680     return DAG.getConstant(0, DL, VT);
8681 
8682   // fold (xor x, undef) -> undef
8683   if (N0.isUndef())
8684     return N0;
8685   if (N1.isUndef())
8686     return N1;
8687 
8688   // fold (xor c1, c2) -> c1^c2
8689   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
8690     return C;
8691 
8692   // canonicalize constant to RHS
8693   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
8694       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
8695     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
8696 
8697   // fold vector ops
8698   if (VT.isVector()) {
8699     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8700       return FoldedVOp;
8701 
8702     // fold (xor x, 0) -> x, vector edition
8703     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
8704       return N0;
8705   }
8706 
8707   // fold (xor x, 0) -> x
8708   if (isNullConstant(N1))
8709     return N0;
8710 
8711   if (SDValue NewSel = foldBinOpIntoSelect(N))
8712     return NewSel;
8713 
8714   // reassociate xor
8715   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
8716     return RXOR;
8717 
8718   // fold (a^b) -> (a|b) iff a and b share no bits.
8719   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
8720       DAG.haveNoCommonBitsSet(N0, N1))
8721     return DAG.getNode(ISD::OR, DL, VT, N0, N1);
8722 
8723   // look for 'add-like' folds:
8724   // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
8725   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8726       isMinSignedConstant(N1))
8727     if (SDValue Combined = visitADDLike(N))
8728       return Combined;
8729 
8730   // fold !(x cc y) -> (x !cc y)
8731   unsigned N0Opcode = N0.getOpcode();
8732   SDValue LHS, RHS, CC;
8733   if (TLI.isConstTrueVal(N1) &&
8734       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
8735     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
8736                                                LHS.getValueType());
8737     if (!LegalOperations ||
8738         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
8739       switch (N0Opcode) {
8740       default:
8741         llvm_unreachable("Unhandled SetCC Equivalent!");
8742       case ISD::SETCC:
8743         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
8744       case ISD::SELECT_CC:
8745         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
8746                                N0.getOperand(3), NotCC);
8747       case ISD::STRICT_FSETCC:
8748       case ISD::STRICT_FSETCCS: {
8749         if (N0.hasOneUse()) {
8750           // FIXME Can we handle multiple uses? Could we token factor the chain
8751           // results from the new/old setcc?
8752           SDValue SetCC =
8753               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
8754                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
8755           CombineTo(N, SetCC);
8756           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
8757           recursivelyDeleteUnusedNodes(N0.getNode());
8758           return SDValue(N, 0); // Return N so it doesn't get rechecked!
8759         }
8760         break;
8761       }
8762       }
8763     }
8764   }
8765 
8766   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
8767   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
8768       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
8769     SDValue V = N0.getOperand(0);
8770     SDLoc DL0(N0);
8771     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
8772                     DAG.getConstant(1, DL0, V.getValueType()));
8773     AddToWorklist(V.getNode());
8774     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
8775   }
8776 
8777   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
8778   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
8779       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8780     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8781     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
8782       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8783       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8784       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8785       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8786       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8787     }
8788   }
8789   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
8790   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
8791       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8792     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8793     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
8794       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8795       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8796       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8797       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8798       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8799     }
8800   }
8801 
8802   // fold (not (neg x)) -> (add X, -1)
8803   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
8804   // Y is a constant or the subtract has a single use.
8805   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
8806       isNullConstant(N0.getOperand(0))) {
8807     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
8808                        DAG.getAllOnesConstant(DL, VT));
8809   }
8810 
8811   // fold (not (add X, -1)) -> (neg X)
8812   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
8813       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
8814     return DAG.getNegative(N0.getOperand(0), DL, VT);
8815   }
8816 
8817   // fold (xor (and x, y), y) -> (and (not x), y)
8818   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
8819     SDValue X = N0.getOperand(0);
8820     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
8821     AddToWorklist(NotX.getNode());
8822     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
8823   }
8824 
8825   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
8826   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
8827     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
8828     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
8829     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
8830       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
8831       SDValue S0 = S.getOperand(0);
8832       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
8833         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
8834           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
8835             return DAG.getNode(ISD::ABS, DL, VT, S0);
8836     }
8837   }
8838 
8839   // fold (xor x, x) -> 0
8840   if (N0 == N1)
8841     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
8842 
8843   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
8844   // Here is a concrete example of this equivalence:
8845   // i16   x ==  14
8846   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
8847   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
8848   //
8849   // =>
8850   //
8851   // i16     ~1      == 0b1111111111111110
8852   // i16 rol(~1, 14) == 0b1011111111111111
8853   //
8854   // Some additional tips to help conceptualize this transform:
8855   // - Try to see the operation as placing a single zero in a value of all ones.
8856   // - There exists no value for x which would allow the result to contain zero.
8857   // - Values of x larger than the bitwidth are undefined and do not require a
8858   //   consistent result.
8859   // - Pushing the zero left requires shifting one-bits in from the right.
8860   // A rotate left of ~1 is a nice way of achieving the desired result.
8861   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
8862       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
8863     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
8864                        N0.getOperand(1));
8865   }
8866 
8867   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
8868   if (N0Opcode == N1.getOpcode())
8869     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8870       return V;
8871 
8872   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
8873     return R;
8874   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
8875     return R;
8876   if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8877     return R;
8878 
8879   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
8880   if (SDValue MM = unfoldMaskedMerge(N))
8881     return MM;
8882 
8883   // Simplify the expression using non-local knowledge.
8884   if (SimplifyDemandedBits(SDValue(N, 0)))
8885     return SDValue(N, 0);
8886 
8887   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8888     return Combined;
8889 
8890   return SDValue();
8891 }
8892 
8893 /// If we have a shift-by-constant of a bitwise logic op that itself has a
8894 /// shift-by-constant operand with identical opcode, we may be able to convert
8895 /// that into 2 independent shifts followed by the logic op. This is a
8896 /// throughput improvement.
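     ///
     /// For example: srl (xor (srl X, 3), Y), 2 --> xor (srl X, 5), (srl Y, 2),
     /// provided the combined shift amount 3 + 2 stays below the bitwidth of X.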
8897 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
8898   // Match a one-use bitwise logic op.
8899   SDValue LogicOp = Shift->getOperand(0);
8900   if (!LogicOp.hasOneUse())
8901     return SDValue();
8902 
8903   unsigned LogicOpcode = LogicOp.getOpcode();
8904   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
8905       LogicOpcode != ISD::XOR)
8906     return SDValue();
8907 
8908   // Find a matching one-use shift by constant.
8909   unsigned ShiftOpcode = Shift->getOpcode();
8910   SDValue C1 = Shift->getOperand(1);
8911   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
8912   assert(C1Node && "Expected a shift with constant operand");
8913   const APInt &C1Val = C1Node->getAPIntValue();
8914   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
8915                              const APInt *&ShiftAmtVal) {
8916     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
8917       return false;
8918 
8919     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
8920     if (!ShiftCNode)
8921       return false;
8922 
8923     // Capture the shifted operand and shift amount value.
8924     ShiftOp = V.getOperand(0);
8925     ShiftAmtVal = &ShiftCNode->getAPIntValue();
8926 
8927     // Shift amount types do not have to match their operand type, so check that
8928     // the constants are the same width.
8929     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
8930       return false;
8931 
8932     // The fold is not valid if the sum of the shift values exceeds bitwidth.
8933     if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
8934       return false;
8935 
8936     return true;
8937   };
8938 
8939   // Logic ops are commutative, so check each operand for a match.
8940   SDValue X, Y;
8941   const APInt *C0Val;
8942   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
8943     Y = LogicOp.getOperand(1);
8944   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
8945     Y = LogicOp.getOperand(0);
8946   else
8947     return SDValue();
8948 
8949   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
8950   SDLoc DL(Shift);
8951   EVT VT = Shift->getValueType(0);
8952   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
8953   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
8954   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
8955   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
8956   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
8957 }
8958 
8959 /// Handle transforms common to the three shifts, when the shift amount is a
8960 /// constant.
8961 /// We are looking for: (shift being one of shl/sra/srl)
8962 ///   shift (binop X, C0), C1
8963 /// And want to transform into:
8964 ///   binop (shift X, C1), (shift C0, C1)
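     ///
     /// For example: shl (add X, 7), 2 --> add (shl X, 2), 28, since 7 << 2 == 28.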
8965 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
8966   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
8967 
8968   // Do not turn a 'not' into a regular xor.
8969   if (isBitwiseNot(N->getOperand(0)))
8970     return SDValue();
8971 
8972   // The inner binop must be one-use, since we want to replace it.
8973   SDValue LHS = N->getOperand(0);
8974   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
8975     return SDValue();
8976 
8977   // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
8978   if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
8979     return R;
8980 
8981   // We want to pull some binops through shifts, so that we have (and (shift))
8982   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
8983   // thing happens with address calculations, so it's important to canonicalize
8984   // it.
8985   switch (LHS.getOpcode()) {
8986   default:
8987     return SDValue();
8988   case ISD::OR:
8989   case ISD::XOR:
8990   case ISD::AND:
8991     break;
8992   case ISD::ADD:
8993     if (N->getOpcode() != ISD::SHL)
8994       return SDValue(); // only shl(add) not sr[al](add).
8995     break;
8996   }
8997 
8998   // FIXME: disable this unless the input to the binop is a shift by a constant
8999   // or is copy/select. Enable this in other cases once we figure out when it
9000   // is exactly profitable.
9001   SDValue BinOpLHSVal = LHS.getOperand(0);
9002   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9003                             BinOpLHSVal.getOpcode() == ISD::SRA ||
9004                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
9005                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
9006   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9007                         BinOpLHSVal.getOpcode() == ISD::SELECT;
9008 
9009   if (!IsShiftByConstant && !IsCopyOrSelect)
9010     return SDValue();
9011 
9012   if (IsCopyOrSelect && N->hasOneUse())
9013     return SDValue();
9014 
9015   // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9016   SDLoc DL(N);
9017   EVT VT = N->getValueType(0);
9018   if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9019           N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
9020     SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
9021                                    N->getOperand(1));
9022     return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
9023   }
9024 
9025   return SDValue();
9026 }
9027 
9028 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9029   assert(N->getOpcode() == ISD::TRUNCATE);
9030   assert(N->getOperand(0).getOpcode() == ISD::AND);
9031 
9032   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
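       // For example: (truncate:i16 (and X:i32, 0x00FF00FF))
       //          --> (and (truncate:i16 X), 0x00FF)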
9033   EVT TruncVT = N->getValueType(0);
9034   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
9035       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
9036     SDValue N01 = N->getOperand(0).getOperand(1);
9037     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
9038       SDLoc DL(N);
9039       SDValue N00 = N->getOperand(0).getOperand(0);
9040       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
9041       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
9042       AddToWorklist(Trunc00.getNode());
9043       AddToWorklist(Trunc01.getNode());
9044       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
9045     }
9046   }
9047 
9048   return SDValue();
9049 }
9050 
9051 SDValue DAGCombiner::visitRotate(SDNode *N) {
9052   SDLoc dl(N);
9053   SDValue N0 = N->getOperand(0);
9054   SDValue N1 = N->getOperand(1);
9055   EVT VT = N->getValueType(0);
9056   unsigned Bitsize = VT.getScalarSizeInBits();
9057 
9058   // fold (rot x, 0) -> x
9059   if (isNullOrNullSplat(N1))
9060     return N0;
9061 
9062   // fold (rot x, c) -> x iff (c % BitSize) == 0
9063   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
9064     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9065     if (DAG.MaskedValueIsZero(N1, ModuloMask))
9066       return N0;
9067   }
9068 
9069   // fold (rot x, c) -> (rot x, c % BitSize)
9070   bool OutOfRange = false;
9071   auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9072     OutOfRange |= C->getAPIntValue().uge(Bitsize);
9073     return true;
9074   };
9075   if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
9076     EVT AmtVT = N1.getValueType();
9077     SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
9078     if (SDValue Amt =
9079             DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
9080       return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
9081   }
9082 
9083   // rot i16 X, 8 --> bswap X
9084   auto *RotAmtC = isConstOrConstSplat(N1);
9085   if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9086       VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
9087     return DAG.getNode(ISD::BSWAP, dl, VT, N0);
9088 
9089   // Simplify the operands using demanded-bits information.
9090   if (SimplifyDemandedBits(SDValue(N, 0)))
9091     return SDValue(N, 0);
9092 
9093   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9094   if (N1.getOpcode() == ISD::TRUNCATE &&
9095       N1.getOperand(0).getOpcode() == ISD::AND) {
9096     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9097       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
9098   }
9099 
9100   unsigned NextOp = N0.getOpcode();
9101 
9102   // fold (rot* (rot* x, c2), c1)
9103   //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
9104   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9105     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
9106     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
9107     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
9108       EVT ShiftVT = C1->getValueType(0);
9109       bool SameSide = (N->getOpcode() == NextOp);
9110       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9111       SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
9112       SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9113                                                  {N1, BitsizeC});
9114       SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9115                                                  {N0.getOperand(1), BitsizeC});
9116       if (Norm1 && Norm2)
9117         if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9118                 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
9119           CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
9120                                                      {CombinedShift, BitsizeC});
9121           SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9122               ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
9123           return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
9124                              CombinedShiftNorm);
9125         }
9126     }
9127   }
9128   return SDValue();
9129 }
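// Worked examples for the rotate folds above (i8 used for illustration):
//   (rotl i8 %x, 8)              --> %x              ; 8 % 8 == 0
//   (rotl i8 %x, 9)              --> (rotl i8 %x, 1) ; amount taken modulo 8
//   (rotl i8 (rotl i8 %x, 3), 7) --> (rotl i8 %x, 2) ; (7 + 3 + 8) % 8 == 2
//   (rotl i8 (rotr i8 %x, 3), 7) --> (rotl i8 %x, 4) ; (7 - 3 + 8) % 8 == 4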
9130 
9131 SDValue DAGCombiner::visitSHL(SDNode *N) {
9132   SDValue N0 = N->getOperand(0);
9133   SDValue N1 = N->getOperand(1);
9134   if (SDValue V = DAG.simplifyShift(N0, N1))
9135     return V;
9136 
9137   EVT VT = N0.getValueType();
9138   EVT ShiftVT = N1.getValueType();
9139   unsigned OpSizeInBits = VT.getScalarSizeInBits();
9140 
9141   // fold (shl c1, c2) -> c1<<c2
9142   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
9143     return C;
9144 
9145   // fold vector ops
9146   if (VT.isVector()) {
9147     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9148       return FoldedVOp;
9149 
9150     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
9151     // If setcc produces an all-ones true value then:
9152     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9153     if (N1CV && N1CV->isConstant()) {
9154       if (N0.getOpcode() == ISD::AND) {
9155         SDValue N00 = N0->getOperand(0);
9156         SDValue N01 = N0->getOperand(1);
9157         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
9158 
9159         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9160             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
9161                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9162           if (SDValue C =
9163                   DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
9164             return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
9165         }
9166       }
9167     }
9168   }
9169 
9170   if (SDValue NewSel = foldBinOpIntoSelect(N))
9171     return NewSel;
9172 
9173   // if (shl x, c) is known to be zero, return 0
9174   if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9175     return DAG.getConstant(0, SDLoc(N), VT);
9176 
9177   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9178   if (N1.getOpcode() == ISD::TRUNCATE &&
9179       N1.getOperand(0).getOpcode() == ISD::AND) {
9180     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9181       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
9182   }
9183 
9184   if (SimplifyDemandedBits(SDValue(N, 0)))
9185     return SDValue(N, 0);
9186 
9187   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
9188   if (N0.getOpcode() == ISD::SHL) {
9189     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9190                                           ConstantSDNode *RHS) {
9191       APInt c1 = LHS->getAPIntValue();
9192       APInt c2 = RHS->getAPIntValue();
9193       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9194       return (c1 + c2).uge(OpSizeInBits);
9195     };
9196     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9197       return DAG.getConstant(0, SDLoc(N), VT);
9198 
9199     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9200                                        ConstantSDNode *RHS) {
9201       APInt c1 = LHS->getAPIntValue();
9202       APInt c2 = RHS->getAPIntValue();
9203       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9204       return (c1 + c2).ult(OpSizeInBits);
9205     };
9206     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9207       SDLoc DL(N);
9208       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9209       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
9210     }
9211   }
9212 
9213   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9214   // For this to be valid, the second form must not preserve any of the bits
9215   // that are shifted out by the inner shift in the first form.  This means
9216   // the outer shift size must be >= the number of bits added by the ext.
9217   // As a corollary, we don't care what kind of ext it is.
9218   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9219        N0.getOpcode() == ISD::ANY_EXTEND ||
9220        N0.getOpcode() == ISD::SIGN_EXTEND) &&
9221       N0.getOperand(0).getOpcode() == ISD::SHL) {
9222     SDValue N0Op0 = N0.getOperand(0);
9223     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9224     EVT InnerVT = N0Op0.getValueType();
9225     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9226 
9227     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9228                                                          ConstantSDNode *RHS) {
9229       APInt c1 = LHS->getAPIntValue();
9230       APInt c2 = RHS->getAPIntValue();
9231       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9232       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9233              (c1 + c2).uge(OpSizeInBits);
9234     };
9235     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
9236                                   /*AllowUndefs*/ false,
9237                                   /*AllowTypeMismatch*/ true))
9238       return DAG.getConstant(0, SDLoc(N), VT);
9239 
9240     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9241                                                       ConstantSDNode *RHS) {
9242       APInt c1 = LHS->getAPIntValue();
9243       APInt c2 = RHS->getAPIntValue();
9244       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9245       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9246              (c1 + c2).ult(OpSizeInBits);
9247     };
9248     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
9249                                   /*AllowUndefs*/ false,
9250                                   /*AllowTypeMismatch*/ true)) {
9251       SDLoc DL(N);
9252       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
9253       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
9254       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
9255       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
9256     }
9257   }
9258 
9259   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9260   // Only fold this if the inner zext has no other uses to avoid increasing
9261   // the total number of instructions.
9262   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9263       N0.getOperand(0).getOpcode() == ISD::SRL) {
9264     SDValue N0Op0 = N0.getOperand(0);
9265     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9266 
9267     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9268       APInt c1 = LHS->getAPIntValue();
9269       APInt c2 = RHS->getAPIntValue();
9270       zeroExtendToMatch(c1, c2);
9271       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
9272     };
9273     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
9274                                   /*AllowUndefs*/ false,
9275                                   /*AllowTypeMismatch*/ true)) {
9276       SDLoc DL(N);
9277       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
9278       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
9279       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
9280       AddToWorklist(NewSHL.getNode());
9281       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
9282     }
9283   }
9284 
9285   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9286     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9287                                            ConstantSDNode *RHS) {
9288       const APInt &LHSC = LHS->getAPIntValue();
9289       const APInt &RHSC = RHS->getAPIntValue();
9290       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9291              LHSC.getZExtValue() <= RHSC.getZExtValue();
9292     };
9293 
9294     SDLoc DL(N);
9295 
9296     // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
9297     // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
9298     if (N0->getFlags().hasExact()) {
9299       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9300                                     /*AllowUndefs*/ false,
9301                                     /*AllowTypeMismatch*/ true)) {
9302         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9303         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9304         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9305       }
9306       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9307                                     /*AllowUndefs*/ false,
9308                                     /*AllowTypeMismatch*/ true)) {
9309         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9310         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9311         return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
9312       }
9313     }
9314 
9315     // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
9316     //                               (and (srl x, (sub c1, c2)), MASK)
9317     // Only fold this if the inner shift has no other uses -- if it does,
9318     // folding this will increase the total number of instructions.
9319     if (N0.getOpcode() == ISD::SRL &&
9320         (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
9321         TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9322       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9323                                     /*AllowUndefs*/ false,
9324                                     /*AllowTypeMismatch*/ true)) {
9325         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9326         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9327         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9328         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
9329         Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
9330         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9331         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9332       }
9333       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9334                                     /*AllowUndefs*/ false,
9335                                     /*AllowTypeMismatch*/ true)) {
9336         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9337         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9338         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9339         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
9340         SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9341         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9342       }
9343     }
9344   }
9345 
9346   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
9347   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
9348       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
9349     SDLoc DL(N);
9350     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
9351     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
9352     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
9353   }
9354 
9355   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
9356   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
9357   // Variant of version done on multiply, except mul by a power of 2 is turned
9358   // into a shift.
9359   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
9360       N0->hasOneUse() &&
9361       isConstantOrConstantVector(N1, /* No Opaques */ true) &&
9362       isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
9363       TLI.isDesirableToCommuteWithShift(N, Level)) {
9364     SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
9365     SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
9366     AddToWorklist(Shl0.getNode());
9367     AddToWorklist(Shl1.getNode());
9368     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
9369   }
9370 
9371   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
9372   if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
9373     SDValue N01 = N0.getOperand(1);
9374     if (SDValue Shl =
9375             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
9376       return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
9377   }
9378 
9379   ConstantSDNode *N1C = isConstOrConstSplat(N1);
9380   if (N1C && !N1C->isOpaque())
9381     if (SDValue NewSHL = visitShiftByConstant(N))
9382       return NewSHL;
9383 
9384   // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
9385   if (N0.getOpcode() == ISD::VSCALE && N1C) {
9386     const APInt &C0 = N0.getConstantOperandAPInt(0);
9387     const APInt &C1 = N1C->getAPIntValue();
9388     return DAG.getVScale(SDLoc(N), VT, C0 << C1);
9389   }
9390 
9391   // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
9392   APInt ShlVal;
9393   if (N0.getOpcode() == ISD::STEP_VECTOR &&
9394       ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
9395     const APInt &C0 = N0.getConstantOperandAPInt(0);
9396     if (ShlVal.ult(C0.getBitWidth())) {
9397       APInt NewStep = C0 << ShlVal;
9398       return DAG.getStepVector(SDLoc(N), VT, NewStep);
9399     }
9400   }
9401 
9402   return SDValue();
9403 }
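// Worked examples for the shl-of-shl folds above (i8 used for illustration):
//   (shl i8 (shl i8 %x, 2), 3) --> (shl i8 %x, 5) ; 2 + 3 < 8
//   (shl i8 (shl i8 %x, 5), 4) --> 0              ; 5 + 4 >= 8, all bits lost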
9404 
9405 // Transform a right shift of a multiply into a multiply-high.
9406 // Examples:
9407 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
9408 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
9409 static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
9410                                   const TargetLowering &TLI) {
9411   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
9412          "SRL or SRA node is required here!");
9413 
9414   // Check the shift amount. Proceed with the transformation if the shift
9415   // amount is constant.
9416   ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
9417   if (!ShiftAmtSrc)
9418     return SDValue();
9419 
9420   SDLoc DL(N);
9421 
9422   // The operation feeding into the shift must be a multiply.
9423   SDValue ShiftOperand = N->getOperand(0);
9424   if (ShiftOperand.getOpcode() != ISD::MUL)
9425     return SDValue();
9426 
9427   // Both operands must be equivalent extend nodes.
9428   SDValue LeftOp = ShiftOperand.getOperand(0);
9429   SDValue RightOp = ShiftOperand.getOperand(1);
9430 
9431   bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
9432   bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
9433 
9434   if (!IsSignExt && !IsZeroExt)
9435     return SDValue();
9436 
9437   EVT NarrowVT = LeftOp.getOperand(0).getValueType();
9438   unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
9439 
9440   // return true if U may use the lower bits of its operands
9441   auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
9442     if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
9443       return true;
9444     }
9445     ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
9446     if (!UShiftAmtSrc) {
9447       return true;
9448     }
9449     unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
9450     return UShiftAmt < NarrowVTSize;
9451   };
9452 
9453   // If the lower part of the MUL is also used and MUL_LOHI is supported,
9454   // do not introduce the MULH in favor of MUL_LOHI.
9455   unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
9456   if (!ShiftOperand.hasOneUse() &&
9457       TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
9458       llvm::any_of(ShiftOperand->uses(), UserOfLowerBits)) {
9459     return SDValue();
9460   }
9461 
9462   SDValue MulhRightOp;
9463   if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
9464     unsigned ActiveBits = IsSignExt
9465                               ? Constant->getAPIntValue().getMinSignedBits()
9466                               : Constant->getAPIntValue().getActiveBits();
9467     if (ActiveBits > NarrowVTSize)
9468       return SDValue();
9469     MulhRightOp = DAG.getConstant(
9470         Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
9471         NarrowVT);
9472   } else {
9473     if (LeftOp.getOpcode() != RightOp.getOpcode())
9474       return SDValue();
9475     // Check that the two extend nodes are the same type.
9476     if (NarrowVT != RightOp.getOperand(0).getValueType())
9477       return SDValue();
9478     MulhRightOp = RightOp.getOperand(0);
9479   }
9480 
9481   EVT WideVT = LeftOp.getValueType();
9482   // Proceed with the transformation if the wide types match.
9483   assert((WideVT == RightOp.getValueType()) &&
9484          "Cannot have a multiply node with two different operand types.");
9485 
9486   // Proceed with the transformation if the wide type is twice as large
9487   // as the narrow type.
9488   if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
9489     return SDValue();
9490 
9491   // Check the shift amount with the narrow type size.
9492   // Proceed with the transformation if the shift amount is the width
9493   // of the narrow type.
9494   unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
9495   if (ShiftAmt != NarrowVTSize)
9496     return SDValue();
9497 
9498   // If the operation feeding into the MUL is a sign extend (sext),
9499   // we use mulhs. Otherwise, zero extends (zext) use mulhu.
9500   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
9501 
9502   // Combine to mulh if mulh is legal/custom for the narrow type on the target.
9503   if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
9504     return SDValue();
9505 
9506   SDValue Result =
9507       DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
9508   return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT)
9509                                      : DAG.getZExtOrTrunc(Result, DL, WideVT));
9510 }
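// Worked example for combineShiftToMULH (assuming MULHU is legal or custom
// for i32 on the target):
//   (srl (mul (zext i32 %a to i64), (zext i32 %b to i64)), 32)
//     --> (zext (mulhu i32 %a, %b) to i64)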
9511 
9512 SDValue DAGCombiner::visitSRA(SDNode *N) {
9513   SDValue N0 = N->getOperand(0);
9514   SDValue N1 = N->getOperand(1);
9515   if (SDValue V = DAG.simplifyShift(N0, N1))
9516     return V;
9517 
9518   EVT VT = N0.getValueType();
9519   unsigned OpSizeInBits = VT.getScalarSizeInBits();
9520 
9521   // fold (sra c1, c2) -> c1 >>s c2
9522   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
9523     return C;
9524 
9525   // Arithmetic shifting an all-sign-bit value is a no-op.
9526   // fold (sra 0, x) -> 0
9527   // fold (sra -1, x) -> -1
9528   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
9529     return N0;
9530 
9531   // fold vector ops
9532   if (VT.isVector())
9533     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9534       return FoldedVOp;
9535 
9536   if (SDValue NewSel = foldBinOpIntoSelect(N))
9537     return NewSel;
9538 
9539   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 if the target supports
9540   // sext_inreg.
9541   ConstantSDNode *N1C = isConstOrConstSplat(N1);
9542   if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
9543     unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
9544     EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
9545     if (VT.isVector())
9546       ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
9547                                VT.getVectorElementCount());
9548     if (!LegalOperations ||
9549         TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
9550         TargetLowering::Legal)
9551       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
9552                          N0.getOperand(0), DAG.getValueType(ExtVT));
9553     // Even if we can't convert to sext_inreg, we might be able to remove
9554     // this shift pair if the input is already sign extended.
9555     if (DAG.ComputeNumSignBits(N0.getOperand(0)) > N1C->getZExtValue())
9556       return N0.getOperand(0);
9557   }
9558 
9559   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
9560   // clamp (add c1, c2) to max shift.
9561   if (N0.getOpcode() == ISD::SRA) {
9562     SDLoc DL(N);
9563     EVT ShiftVT = N1.getValueType();
9564     EVT ShiftSVT = ShiftVT.getScalarType();
9565     SmallVector<SDValue, 16> ShiftValues;
9566 
9567     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9568       APInt c1 = LHS->getAPIntValue();
9569       APInt c2 = RHS->getAPIntValue();
9570       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9571       APInt Sum = c1 + c2;
9572       unsigned ShiftSum =
9573           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
9574       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
9575       return true;
9576     };
9577     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
9578       SDValue ShiftValue;
9579       if (N1.getOpcode() == ISD::BUILD_VECTOR)
9580         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
9581       else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
9582         assert(ShiftValues.size() == 1 &&
9583                "Expected matchBinaryPredicate to return one element for "
9584                "SPLAT_VECTORs");
9585         ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
9586       } else
9587         ShiftValue = ShiftValues[0];
9588       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
9589     }
9590   }
9591 
9592   // fold (sra (shl X, m), (sub result_size, n))
9593   // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
9594   // result_size - n != m.
9595   // If truncate is free for the target, this is likely to result in better
9596   // code than the sext(shl) form.
9597   if (N0.getOpcode() == ISD::SHL && N1C) {
9598     // Get the two constants of the shifts, CN0 = m, CN = n.
9599     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
9600     if (N01C) {
9601       LLVMContext &Ctx = *DAG.getContext();
9602       // Determine what the truncate's result bitsize and type would be.
9603       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
9604 
9605       if (VT.isVector())
9606         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9607 
9608       // Determine the residual right-shift amount.
9609       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
9610 
9611       // If the shift is not a no-op (in which case this should be just a sign
9612       // extend already), the truncate's result type is legal, sign_extend is legal
9613       // on that type, and the truncate to that type is both legal and free,
9614       // perform the transform.
9615       if ((ShiftAmt > 0) &&
9616           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
9617           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
9618           TLI.isTruncateFree(VT, TruncVT)) {
9619         SDLoc DL(N);
9620         SDValue Amt = DAG.getConstant(ShiftAmt, DL,
9621             getShiftAmountTy(N0.getOperand(0).getValueType()));
9622         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
9623                                     N0.getOperand(0), Amt);
9624         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
9625                                     Shift);
9626         return DAG.getNode(ISD::SIGN_EXTEND, DL,
9627                            N->getValueType(0), Trunc);
9628       }
9629     }
9630   }
9631 
9632   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
9633   //   sra (add (shl X, N1C), AddC), N1C -->
9634   //   sext (add (trunc X to (width - N1C)), AddC')
9635   //   sra (sub AddC, (shl X, N1C)), N1C -->
9636   //   sext (sub AddC1',(trunc X to (width - N1C)))
9637   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
9638       N0.hasOneUse()) {
9639     bool IsAdd = N0.getOpcode() == ISD::ADD;
9640     SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
9641     if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
9642         Shl.hasOneUse()) {
9643       // TODO: AddC does not need to be a splat.
9644       if (ConstantSDNode *AddC =
9645               isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
9646         // Determine what the truncate's type would be and ask the target if
9647         // that is a free operation.
9648         LLVMContext &Ctx = *DAG.getContext();
9649         unsigned ShiftAmt = N1C->getZExtValue();
9650         EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
9651         if (VT.isVector())
9652           TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9653 
9654         // TODO: The simple type check probably belongs in the default hook
9655         //       implementation and/or target-specific overrides (because
9656         //       non-simple types likely require masking when legalized), but
9657         //       that restriction may conflict with other transforms.
9658         if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
9659             TLI.isTruncateFree(VT, TruncVT)) {
9660           SDLoc DL(N);
9661           SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
9662           SDValue ShiftC =
9663               DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
9664                                   TruncVT.getScalarSizeInBits()),
9665                               DL, TruncVT);
9666           SDValue Add;
9667           if (IsAdd)
9668             Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
9669           else
9670             Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
9671           return DAG.getSExtOrTrunc(Add, DL, VT);
9672         }
9673       }
9674     }
9675   }
9676 
9677   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
9678   if (N1.getOpcode() == ISD::TRUNCATE &&
9679       N1.getOperand(0).getOpcode() == ISD::AND) {
9680     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9681       return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
9682   }
9683 
9684   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
9685   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
9686   //      if c1 is equal to the number of bits the trunc removes
9687   // TODO - support non-uniform vector shift amounts.
9688   if (N0.getOpcode() == ISD::TRUNCATE &&
9689       (N0.getOperand(0).getOpcode() == ISD::SRL ||
9690        N0.getOperand(0).getOpcode() == ISD::SRA) &&
9691       N0.getOperand(0).hasOneUse() &&
9692       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
9693     SDValue N0Op0 = N0.getOperand(0);
9694     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
9695       EVT LargeVT = N0Op0.getValueType();
9696       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
9697       if (LargeShift->getAPIntValue() == TruncBits) {
9698         SDLoc DL(N);
9699         EVT LargeShiftVT = getShiftAmountTy(LargeVT);
9700         SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
9701         Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
9702                           DAG.getConstant(TruncBits, DL, LargeShiftVT));
9703         SDValue SRA =
9704             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
9705         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
9706       }
9707     }
9708   }
9709 
9710   // Simplify, based on bits shifted out of the LHS.
9711   if (SimplifyDemandedBits(SDValue(N, 0)))
9712     return SDValue(N, 0);
9713 
9714   // If the sign bit is known to be zero, switch this to a SRL.
9715   if (DAG.SignBitIsZero(N0))
9716     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
9717 
9718   if (N1C && !N1C->isOpaque())
9719     if (SDValue NewSRA = visitShiftByConstant(N))
9720       return NewSRA;
9721 
9722   // Try to transform this shift into a multiply-high if
9723   // it matches the appropriate pattern detected in combineShiftToMULH.
9724   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9725     return MULH;
9726 
9727   // Attempt to convert a sra of a load into a narrower sign-extending load.
9728   if (SDValue NarrowLoad = reduceLoadWidth(N))
9729     return NarrowLoad;
9730 
9731   return SDValue();
9732 }
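// Worked examples for the sra folds above (i32 used for illustration):
//   (sra (shl %x, 24), 24) --> (sext_inreg %x, i8) ; if sext_inreg is legal
//   (sra (sra %x, 10), 30) --> (sra %x, 31)        ; 10 + 30 clamped to 31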
9733 
9734 SDValue DAGCombiner::visitSRL(SDNode *N) {
9735   SDValue N0 = N->getOperand(0);
9736   SDValue N1 = N->getOperand(1);
9737   if (SDValue V = DAG.simplifyShift(N0, N1))
9738     return V;
9739 
9740   EVT VT = N0.getValueType();
9741   EVT ShiftVT = N1.getValueType();
9742   unsigned OpSizeInBits = VT.getScalarSizeInBits();
9743 
9744   // fold (srl c1, c2) -> c1 >>u c2
9745   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
9746     return C;
9747 
9748   // fold vector ops
9749   if (VT.isVector())
9750     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9751       return FoldedVOp;
9752 
9753   if (SDValue NewSel = foldBinOpIntoSelect(N))
9754     return NewSel;
9755 
9756   // if (srl x, c) is known to be zero, return 0
9757   ConstantSDNode *N1C = isConstOrConstSplat(N1);
9758   if (N1C &&
9759       DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9760     return DAG.getConstant(0, SDLoc(N), VT);
9761 
9762   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
9763   if (N0.getOpcode() == ISD::SRL) {
9764     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9765                                           ConstantSDNode *RHS) {
9766       APInt c1 = LHS->getAPIntValue();
9767       APInt c2 = RHS->getAPIntValue();
9768       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9769       return (c1 + c2).uge(OpSizeInBits);
9770     };
9771     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9772       return DAG.getConstant(0, SDLoc(N), VT);
9773 
9774     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9775                                        ConstantSDNode *RHS) {
9776       APInt c1 = LHS->getAPIntValue();
9777       APInt c2 = RHS->getAPIntValue();
9778       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9779       return (c1 + c2).ult(OpSizeInBits);
9780     };
9781     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9782       SDLoc DL(N);
9783       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9784       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
9785     }
9786   }
9787 
9788   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
9789       N0.getOperand(0).getOpcode() == ISD::SRL) {
9790     SDValue InnerShift = N0.getOperand(0);
9791     // TODO - support non-uniform vector shift amounts.
9792     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
9793       uint64_t c1 = N001C->getZExtValue();
9794       uint64_t c2 = N1C->getZExtValue();
9795       EVT InnerShiftVT = InnerShift.getValueType();
9796       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
9797       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
9798       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
9799       // This is only valid if the OpSizeInBits + c1 = size of inner shift.
9800       if (c1 + OpSizeInBits == InnerShiftSize) {
9801         SDLoc DL(N);
9802         if (c1 + c2 >= InnerShiftSize)
9803           return DAG.getConstant(0, DL, VT);
9804         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9805         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9806                                        InnerShift.getOperand(0), NewShiftAmt);
9807         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
9808       }
9809       // In the more general case, we can clear the high bits after the shift:
9810       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
9811       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
9812           c1 + c2 < InnerShiftSize) {
9813         SDLoc DL(N);
9814         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9815         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9816                                        InnerShift.getOperand(0), NewShiftAmt);
9817         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
9818                                                             OpSizeInBits - c2),
9819                                        DL, InnerShiftVT);
9820         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
9821         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
9822       }
9823     }
9824   }
9825 
9826   // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
9827   //                               (and (srl x, (sub c2, c1)), MASK)
9828   if (N0.getOpcode() == ISD::SHL &&
9829       (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
9830       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9831     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9832                                            ConstantSDNode *RHS) {
9833       const APInt &LHSC = LHS->getAPIntValue();
9834       const APInt &RHSC = RHS->getAPIntValue();
9835       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9836              LHSC.getZExtValue() <= RHSC.getZExtValue();
9837     };
9838     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9839                                   /*AllowUndefs*/ false,
9840                                   /*AllowTypeMismatch*/ true)) {
9841       SDLoc DL(N);
9842       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9843       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9844       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9845       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
9846       Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
9847       SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9848       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9849     }
9850     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9851                                   /*AllowUndefs*/ false,
9852                                   /*AllowTypeMismatch*/ true)) {
9853       SDLoc DL(N);
9854       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9855       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9856       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9857       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
9858       SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9859       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9860     }
9861   }
9862 
9863   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
9864   // TODO - support non-uniform vector shift amounts.
9865   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
9866     // Shifting in all undef bits?
9867     EVT SmallVT = N0.getOperand(0).getValueType();
9868     unsigned BitSize = SmallVT.getScalarSizeInBits();
9869     if (N1C->getAPIntValue().uge(BitSize))
9870       return DAG.getUNDEF(VT);
9871 
9872     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
9873       uint64_t ShiftAmt = N1C->getZExtValue();
9874       SDLoc DL0(N0);
9875       SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
9876                                        N0.getOperand(0),
9877                           DAG.getConstant(ShiftAmt, DL0,
9878                                           getShiftAmountTy(SmallVT)));
9879       AddToWorklist(SmallShift.getNode());
9880       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
9881       SDLoc DL(N);
9882       return DAG.getNode(ISD::AND, DL, VT,
9883                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
9884                          DAG.getConstant(Mask, DL, VT));
9885     }
9886   }
9887 
9888   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
9889   // bit, which is unmodified by sra.
9890   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
9891     if (N0.getOpcode() == ISD::SRA)
9892       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
9893   }
9894 
9895   // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).
9896   if (N1C && N0.getOpcode() == ISD::CTLZ &&
9897       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
9898     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
9899 
9900     // If any of the input bits are KnownOne, then the input couldn't be all
9901     // zeros, thus the result of the srl will always be zero.
9902     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
9903 
9904     // If all of the bits input to the ctlz node are known to be zero, then
9905     // the result of the ctlz is "32" and the result of the shift is one.
9906     APInt UnknownBits = ~Known.Zero;
9907     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
9908 
9909     // Otherwise, check to see if there is exactly one bit input to the ctlz.
9910     if (UnknownBits.isPowerOf2()) {
9911       // Okay, we know that only the single bit specified by UnknownBits
9912       // could be set on input to the CTLZ node. If this bit is set, the SRL
9913       // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
9914       // to an SRL/XOR pair, which is likely to simplify more.
9915       unsigned ShAmt = UnknownBits.countTrailingZeros();
9916       SDValue Op = N0.getOperand(0);
9917 
9918       if (ShAmt) {
9919         SDLoc DL(N0);
9920         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
9921                   DAG.getConstant(ShAmt, DL,
9922                                   getShiftAmountTy(Op.getValueType())));
9923         AddToWorklist(Op.getNode());
9924       }
9925 
9926       SDLoc DL(N);
9927       return DAG.getNode(ISD::XOR, DL, VT,
9928                          Op, DAG.getConstant(1, DL, VT));
9929     }
9930   }
9931 
9932   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
9933   if (N1.getOpcode() == ISD::TRUNCATE &&
9934       N1.getOperand(0).getOpcode() == ISD::AND) {
9935     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9936       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
9937   }
9938 
9939   // fold operands of srl based on knowledge that the low bits are not
9940   // demanded.
9941   if (SimplifyDemandedBits(SDValue(N, 0)))
9942     return SDValue(N, 0);
9943 
9944   if (N1C && !N1C->isOpaque())
9945     if (SDValue NewSRL = visitShiftByConstant(N))
9946       return NewSRL;
9947 
9948   // Attempt to convert a srl of a load into a narrower zero-extending load.
9949   if (SDValue NarrowLoad = reduceLoadWidth(N))
9950     return NarrowLoad;
9951 
9952   // Here is a common situation. We want to optimize:
9953   //
9954   //   %a = ...
9955   //   %b = and i32 %a, 2
9956   //   %c = srl i32 %b, 1
9957   //   brcond i32 %c ...
9958   //
9959   // into
9960   //
9961   //   %a = ...
9962   //   %b = and %a, 2
9963   //   %c = setcc eq %b, 0
9964   //   brcond %c ...
9965   //
9966   // However, after the source operand of the SRL is optimized into an AND, the
9966   // SRL
9967   // itself may not be optimized further. Look for it and add the BRCOND into
9968   // the worklist.
9969   //
9970   // This also tends to happen for binary operations when SimplifyDemandedBits
9971   // is involved.
9972   //
9973   // FIXME: This is unnecessary if we process the DAG in topological order,
9974   // which we plan to do. This workaround can be removed once the DAG is
9975   // processed in topological order.
9976   if (N->hasOneUse()) {
9977     SDNode *Use = *N->use_begin();
9978 
9979     // Look past the truncate.
9980     if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
9981       Use = *Use->use_begin();
9982 
9983     if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
9984         Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
9985       AddToWorklist(Use);
9986   }
9987 
9988   // Try to transform this shift into a multiply-high if
9989   // it matches the appropriate pattern detected in combineShiftToMULH.
9990   if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9991     return MULH;
9992 
9993   return SDValue();
9994 }
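// Worked examples for the srl folds above (i32 used for illustration):
//   (srl (srl %x, 10), 30) --> 0             ; 10 + 30 >= 32
//   (srl (sra %x, %y), 31) --> (srl %x, 31)  ; only the sign bit survives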
9995 
9996 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
9997   EVT VT = N->getValueType(0);
9998   SDValue N0 = N->getOperand(0);
9999   SDValue N1 = N->getOperand(1);
10000   SDValue N2 = N->getOperand(2);
10001   bool IsFSHL = N->getOpcode() == ISD::FSHL;
10002   unsigned BitWidth = VT.getScalarSizeInBits();
10003 
10004   // fold (fshl N0, N1, 0) -> N0
10005   // fold (fshr N0, N1, 0) -> N1
10006   if (isPowerOf2_32(BitWidth))
10007     if (DAG.MaskedValueIsZero(
10008             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10009       return IsFSHL ? N0 : N1;
10010 
10011   auto IsUndefOrZero = [](SDValue V) {
10012     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10013   };
10014 
10015   // TODO - support non-uniform vector shift amounts.
10016   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
10017     EVT ShAmtTy = N2.getValueType();
10018 
10019     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10020     if (Cst->getAPIntValue().uge(BitWidth)) {
10021       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
10022       return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
10023                          DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
10024     }
10025 
10026     unsigned ShAmt = Cst->getZExtValue();
10027     if (ShAmt == 0)
10028       return IsFSHL ? N0 : N1;
10029 
10030     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10031     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10032     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10033     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
10034     if (IsUndefOrZero(N0))
10035       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
10036                          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
10037                                          SDLoc(N), ShAmtTy));
10038     if (IsUndefOrZero(N1))
10039       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
10040                          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
10041                                          SDLoc(N), ShAmtTy));
10042 
10043     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10044     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10045     // TODO - bigendian support once we have test coverage.
10046     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10047     // TODO - permit LHS EXTLOAD if extensions are shifted out.
10048     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10049         !DAG.getDataLayout().isBigEndian()) {
10050       auto *LHS = dyn_cast<LoadSDNode>(N0);
10051       auto *RHS = dyn_cast<LoadSDNode>(N1);
10052       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10053           LHS->getAddressSpace() == RHS->getAddressSpace() &&
10054           (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
10055           ISD::isNON_EXTLoad(LHS)) {
10056         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
10057           SDLoc DL(RHS);
10058           uint64_t PtrOff =
10059               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10060           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
10061           unsigned Fast = 0;
10062           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
10063                                      RHS->getAddressSpace(), NewAlign,
10064                                      RHS->getMemOperand()->getFlags(), &Fast) &&
10065               Fast) {
10066             SDValue NewPtr = DAG.getMemBasePlusOffset(
10067                 RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
10068             AddToWorklist(NewPtr.getNode());
10069             SDValue Load = DAG.getLoad(
10070                 VT, DL, RHS->getChain(), NewPtr,
10071                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
10072                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
10073             // Replace the old load's chain with the new load's chain.
10074             WorklistRemover DeadNodes(*this);
10075             DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
10076             return Load;
10077           }
10078         }
10079       }
10080     }
10081   }
10082 
10083   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10084   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10085   // iff we know the shift amount is in range.
10086   // TODO: when is it worth doing SUB(BW, N2) as well?
10087   if (isPowerOf2_32(BitWidth)) {
10088     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10089     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10090       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
10091     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10092       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
10093   }
10094 
10095   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10096   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10097   // TODO: Investigate flipping this rotate if only one is legal; if the funnel
10098   // shift is legal as well, we might be better off avoiding the non-constant
10099   // (BW - N2).
10099   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10100   if (N0 == N1 && hasOperation(RotOpc, VT))
10101     return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
10102 
10103   // Simplify, based on bits shifted out of N0/N1.
10104   if (SimplifyDemandedBits(SDValue(N, 0)))
10105     return SDValue(N, 0);
10106 
10107   return SDValue();
10108 }
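// Worked examples for the funnel-shift folds above (i8 used for illustration):
//   (fshl %x, %y, 11) --> (fshl %x, %y, 3) ; amount taken modulo 8
//   (fshl %x, 0, 3)   --> (shl %x, 3)
//   (fshl %x, %x, %z) --> (rotl %x, %z)    ; if ROTL is legal for the type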
10109 
10110 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10111   SDValue N0 = N->getOperand(0);
10112   SDValue N1 = N->getOperand(1);
10113   if (SDValue V = DAG.simplifyShift(N0, N1))
10114     return V;
10115 
10116   EVT VT = N0.getValueType();
10117 
10118   // fold (*shlsat c1, c2) -> c1<<c2
10119   if (SDValue C =
10120           DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1}))
10121     return C;
10122 
10123   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10124 
10125   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
10126     // fold (sshlsat x, c) -> (shl x, c)
10127     if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10128         N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
10129       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10130 
10131     // fold (ushlsat x, c) -> (shl x, c)
10132     if (N->getOpcode() == ISD::USHLSAT && N1C &&
10133         N1C->getAPIntValue().ule(
10134             DAG.computeKnownBits(N0).countMinLeadingZeros()))
10135       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10136   }
10137 
10138   return SDValue();
10139 }
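// Worked example for the shlsat folds above (i32, illustrative): if
// %x = (sext i8 %a to i32), %x has 25 sign bits, so
//   (sshlsat %x, 24) --> (shl %x, 24) ; 24 < 25, saturation cannot trigger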
10140 
10141 // Given a ABS node, detect the following pattern:
10142 // (ABS (SUB (EXTEND a), (EXTEND b))).
10143 // Generates a UABD/SABD instruction.
10144 SDValue DAGCombiner::foldABSToABD(SDNode *N) {
10145   EVT VT = N->getValueType(0);
10146   SDValue AbsOp1 = N->getOperand(0);
10147   SDValue Op0, Op1;
10148 
10149   if (AbsOp1.getOpcode() != ISD::SUB)
10150     return SDValue();
10151 
10152   Op0 = AbsOp1.getOperand(0);
10153   Op1 = AbsOp1.getOperand(1);
10154 
10155   unsigned Opc0 = Op0.getOpcode();
10156   // Check if the operands of the sub are (zero|sign)-extended.
10157   if (Opc0 != Op1.getOpcode() ||
10158       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) {
10159     // fold (abs (sub nsw x, y)) -> abds(x, y)
10160     if (AbsOp1->getFlags().hasNoSignedWrap() &&
10161         TLI.isOperationLegalOrCustom(ISD::ABDS, VT))
10162       return DAG.getNode(ISD::ABDS, SDLoc(N), VT, Op0, Op1);
10163     return SDValue();
10164   }
10165 
10166   EVT VT1 = Op0.getOperand(0).getValueType();
10167   EVT VT2 = Op1.getOperand(0).getValueType();
10168   unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
10169 
10170   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10171   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10172   // NOTE: Extensions must be equivalent.
10173   if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) {
10174     Op0 = Op0.getOperand(0);
10175     Op1 = Op1.getOperand(0);
10176     SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
10177     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD);
10178   }
10179 
10180   // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10181   // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10182   if (TLI.isOperationLegalOrCustom(ABDOpcode, VT))
10183     return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1);
10184 
10185   return SDValue();
10186 }
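// Worked examples for the ABD folds above (assuming ABDU/ABDS are legal or
// custom for the types involved):
//   (abs (sub (zext i8 %a to i32), (zext i8 %b to i32)))
//     --> (zext (abdu i8 %a, %b) to i32)
//   (abs (sub nsw %x, %y)) --> (abds %x, %y)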
10187 
10188 SDValue DAGCombiner::visitABS(SDNode *N) {
10189   SDValue N0 = N->getOperand(0);
10190   EVT VT = N->getValueType(0);
10191 
10192   // fold (abs c1) -> c2
10193   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10194     return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
10195   // fold (abs (abs x)) -> (abs x)
10196   if (N0.getOpcode() == ISD::ABS)
10197     return N0;
10198   // fold (abs x) -> x iff not-negative
10199   if (DAG.SignBitIsZero(N0))
10200     return N0;
10201 
10202   if (SDValue ABD = foldABSToABD(N))
10203     return ABD;
10204 
10205   // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
10206   // iff zero_extend/truncate are free.
10207   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
10208     EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
10209     if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
10210         TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
10211         hasOperation(ISD::ABS, ExtVT)) {
10212       SDLoc DL(N);
10213       return DAG.getNode(
10214           ISD::ZERO_EXTEND, DL, VT,
10215           DAG.getNode(ISD::ABS, DL, ExtVT,
10216                       DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
10217     }
10218   }
10219 
10220   return SDValue();
10221 }
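// Worked example for the sext_inreg fold above (illustrative, assuming the
// target has a usable i8 ABS and free zext/trunc between i8 and i32):
//   (abs (sext_inreg i32 %x, i8))
//     --> (zext (abs (trunc i32 %x to i8)) to i32)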
10222 
10223 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
10224   SDValue N0 = N->getOperand(0);
10225   EVT VT = N->getValueType(0);
10226   SDLoc DL(N);
10227 
10228   // fold (bswap c1) -> c2
10229   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10230     return DAG.getNode(ISD::BSWAP, DL, VT, N0);
10231   // fold (bswap (bswap x)) -> x
10232   if (N0.getOpcode() == ISD::BSWAP)
10233     return N0.getOperand(0);
10234 
10235   // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
10236   // isn't supported, it will be expanded to bswap followed by a manual reversal
10237   // of bits in each byte. By placing bswaps before bitreverse, we can remove
10238   // the two bswaps if the bitreverse gets expanded.
10239   if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
10240     SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
10241     return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
10242   }
10243 
10244   // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
10245   // iff c >= bw/2 (i.e. the lower half is known zero)
10246   unsigned BW = VT.getScalarSizeInBits();
10247   if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
10248     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
10249     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
10250     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
10251         ShAmt->getZExtValue() >= (BW / 2) &&
10252         (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
10253         TLI.isTruncateFree(VT, HalfVT) &&
10254         (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
10255       SDValue Res = N0.getOperand(0);
10256       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
10257         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
10258                           DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
10259       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
10260       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
10261       return DAG.getZExtOrTrunc(Res, DL, VT);
10262     }
10263   }
10264 
10265   // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
10266   // inverse-shift-of-bswap:
10267   // bswap (X u<< C) --> (bswap X) u>> C
10268   // bswap (X u>> C) --> (bswap X) u<< C
10269   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
10270       N0.hasOneUse()) {
10271     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
10272     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
10273         ShAmt->getZExtValue() % 8 == 0) {
10274       SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
10275       unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
10276       return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
10277     }
10278   }
10279 
10280   return SDValue();
10281 }
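// Worked examples for the bswap folds above (i32 used for illustration):
//   (bswap (shl %x, 16)) --> (zext (bswap (trunc %x to i16)) to i32)
//   (bswap (shl %x, 8))  --> (srl (bswap %x), 8) ; inverse-shift form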
10282 
10283 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
10284   SDValue N0 = N->getOperand(0);
10285   EVT VT = N->getValueType(0);
10286 
10287   // fold (bitreverse c1) -> c2
10288   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10289     return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
10290   // fold (bitreverse (bitreverse x)) -> x
10291   if (N0.getOpcode() == ISD::BITREVERSE)
10292     return N0.getOperand(0);
10293   return SDValue();
10294 }
10295 
10296 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
10297   SDValue N0 = N->getOperand(0);
10298   EVT VT = N->getValueType(0);
10299 
10300   // fold (ctlz c1) -> c2
10301   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10302     return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
10303 
10304   // If the value is known never to be zero, switch to the undef version.
10305   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
10306     if (DAG.isKnownNeverZero(N0))
10307       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10308   }
10309 
10310   return SDValue();
10311 }
10312 
10313 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
10314   SDValue N0 = N->getOperand(0);
10315   EVT VT = N->getValueType(0);
10316 
10317   // fold (ctlz_zero_undef c1) -> c2
10318   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10319     return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10320   return SDValue();
10321 }
10322 
10323 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
10324   SDValue N0 = N->getOperand(0);
10325   EVT VT = N->getValueType(0);
10326 
10327   // fold (cttz c1) -> c2
10328   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10329     return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
10330 
10331   // If the value is known never to be zero, switch to the undef version.
10332   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
10333     if (DAG.isKnownNeverZero(N0))
10334       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10335   }
10336 
10337   return SDValue();
10338 }
10339 
10340 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
10341   SDValue N0 = N->getOperand(0);
10342   EVT VT = N->getValueType(0);
10343 
10344   // fold (cttz_zero_undef c1) -> c2
10345   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10346     return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10347   return SDValue();
10348 }
10349 
10350 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
10351   SDValue N0 = N->getOperand(0);
10352   EVT VT = N->getValueType(0);
10353 
10354   // fold (ctpop c1) -> c2
10355   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10356     return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
10357   return SDValue();
10358 }
10359 
10360 // FIXME: This should be checking for no signed zeros on individual operands, as
10361 // well as no NaNs.
10362 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
10363                                          SDValue RHS,
10364                                          const TargetLowering &TLI) {
10365   const TargetOptions &Options = DAG.getTarget().Options;
10366   EVT VT = LHS.getValueType();
10367 
10368   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
10369          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
10370          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
10371 }
10372 
10373 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
10374                                        SDValue RHS, SDValue True, SDValue False,
10375                                        ISD::CondCode CC,
10376                                        const TargetLowering &TLI,
10377                                        SelectionDAG &DAG) {
10378   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
10379   switch (CC) {
10380   case ISD::SETOLT:
10381   case ISD::SETOLE:
10382   case ISD::SETLT:
10383   case ISD::SETLE:
10384   case ISD::SETULT:
10385   case ISD::SETULE: {
10386     // Since the operands are already known never to be NaN here, either fminnum
10387     // or fminnum_ieee is OK. Try the IEEE version first, since fminnum is
10388     // expanded in terms of it.
10389     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
10390     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10391       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10392 
10393     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
10394     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10395       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10396     return SDValue();
10397   }
10398   case ISD::SETOGT:
10399   case ISD::SETOGE:
10400   case ISD::SETGT:
10401   case ISD::SETGE:
10402   case ISD::SETUGT:
10403   case ISD::SETUGE: {
10404     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
10405     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10406       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10407 
10408     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
10409     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10410       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10411     return SDValue();
10412   }
10413   default:
10414     return SDValue();
10415   }
10416 }
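// Worked example (illustrative, editorial addition): for
// (select (setolt x, y), x, y), LHS == True, so the SETOLT case emits
// fminnum_ieee(x, y) (or fminnum(x, y) if only that form is available);
// with the operands swapped, (select (setolt x, y), y, x) emits the
// corresponding fmaxnum instead.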
10417 
10418 /// Generate Min/Max node
10419 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
10420                                          SDValue RHS, SDValue True,
10421                                          SDValue False, ISD::CondCode CC) {
10422   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
10423     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
10424 
10425   // If we can't directly match this, try to see if we can pull an fneg out of
10426   // the select.
10427   SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
10428       True, DAG, LegalOperations, ForCodeSize);
10429   if (!NegTrue)
10430     return SDValue();
10431 
10432   HandleSDNode NegTrueHandle(NegTrue);
10433 
10434   // Try to unfold an fneg from the select if we are comparing the negated
10435   // constant.
10436   //
10437   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
10438   //
10439   // TODO: Handle fabs
10440   if (LHS == NegTrue) {
10441     // See if the RHS of the compare can also be cheaply negated; if its
10442     // negation matches the false operand, the whole select matches the pattern.
10443     SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
10444         RHS, DAG, LegalOperations, ForCodeSize);
10445     if (NegRHS) {
10446       HandleSDNode NegRHSHandle(NegRHS);
10447       if (NegRHS == False) {
10448         SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
10449                                                    False, CC, TLI, DAG);
10450         if (Combined)
10451           return DAG.getNode(ISD::FNEG, DL, VT, Combined);
10452       }
10453     }
10454   }
10455 
10456   return SDValue();
10457 }
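// Worked example (illustrative, editorial addition):
// (select (setolt x, K), (fneg x), -K) computes -(x < K ? x : K), so it
// folds to fneg(fminnum(x, K)). E.g. with x = 3.0 and K = 2.0 the select
// yields -2.0, and fneg(fminnum(3.0, 2.0)) = -2.0 as well.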
10458 
10459 /// If a (v)select has a condition value that is a sign-bit test, try to smear
10460 /// the condition operand sign-bit across the value width and use it as a mask.
10461 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
10462   SDValue Cond = N->getOperand(0);
10463   SDValue C1 = N->getOperand(1);
10464   SDValue C2 = N->getOperand(2);
10465   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
10466     return SDValue();
10467 
10468   EVT VT = N->getValueType(0);
10469   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
10470       VT != Cond.getOperand(0).getValueType())
10471     return SDValue();
10472 
10473   // The inverted-condition + commuted-select variants of these patterns are
10474   // canonicalized to these forms in IR.
10475   SDValue X = Cond.getOperand(0);
10476   SDValue CondC = Cond.getOperand(1);
10477   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10478   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
10479       isAllOnesOrAllOnesSplat(C2)) {
10480     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
10481     SDLoc DL(N);
10482     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10483     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10484     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
10485   }
10486   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
10487     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
10488     SDLoc DL(N);
10489     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10490     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10491     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
10492   }
10493   return SDValue();
10494 }
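// Worked example (illustrative, editorial addition): for i8 and
// (X s< 0 ? C1 : 0), X = -5 gives (X s>> 7) = 0xFF and 0xFF & C1 = C1,
// while X = 5 gives (X s>> 7) = 0 and 0 & C1 = 0, reproducing both arms of
// the select without a branch or conditional move.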
10495 
10496 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
10497                                                  const TargetLowering &TLI) {
10498   if (!TLI.convertSelectOfConstantsToMath(VT))
10499     return false;
10500 
10501   if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
10502     return true;
10503   if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
10504     return true;
10505 
10506   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10507   if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
10508     return true;
10509   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
10510     return true;
10511 
10512   return false;
10513 }
10514 
10515 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
10516   SDValue Cond = N->getOperand(0);
10517   SDValue N1 = N->getOperand(1);
10518   SDValue N2 = N->getOperand(2);
10519   EVT VT = N->getValueType(0);
10520   EVT CondVT = Cond.getValueType();
10521   SDLoc DL(N);
10522 
10523   if (!VT.isInteger())
10524     return SDValue();
10525 
10526   auto *C1 = dyn_cast<ConstantSDNode>(N1);
10527   auto *C2 = dyn_cast<ConstantSDNode>(N2);
10528   if (!C1 || !C2)
10529     return SDValue();
10530 
10531   if (CondVT != MVT::i1 || LegalOperations) {
10532     // fold (select Cond, 0, 1) -> (xor Cond, 1)
10533     // We can't do this reliably if integer-based booleans have different contents
10534     // than floating-point-based booleans. This is because we can't tell whether we
10535     // have an integer-based boolean or a floating-point-based boolean unless we
10536     // can find the SETCC that produced it and inspect its operands. This is
10537     // fairly easy if C is the SETCC node, but it can potentially be
10538     // undiscoverable (or not reasonably discoverable). For example, it could be
10539     // in another basic block or it could require searching a complicated
10540     // expression.
10541     if (CondVT.isInteger() &&
10542         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
10543             TargetLowering::ZeroOrOneBooleanContent &&
10544         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
10545             TargetLowering::ZeroOrOneBooleanContent &&
10546         C1->isZero() && C2->isOne()) {
10547       SDValue NotCond =
10548           DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
10549       if (VT.bitsEq(CondVT))
10550         return NotCond;
10551       return DAG.getZExtOrTrunc(NotCond, DL, VT);
10552     }
10553 
10554     return SDValue();
10555   }
10556 
10557   // Only do this before legalization to avoid conflicting with target-specific
10558   // transforms in the other direction (create a select from a zext/sext). There
10559   // is also a target-independent combine here in DAGCombiner in the other
10560   // direction for (select Cond, -1, 0) when the condition is not i1.
10561   assert(CondVT == MVT::i1 && !LegalOperations);
10562 
10563   // select Cond, 1, 0 --> zext (Cond)
10564   if (C1->isOne() && C2->isZero())
10565     return DAG.getZExtOrTrunc(Cond, DL, VT);
10566 
10567   // select Cond, -1, 0 --> sext (Cond)
10568   if (C1->isAllOnes() && C2->isZero())
10569     return DAG.getSExtOrTrunc(Cond, DL, VT);
10570 
10571   // select Cond, 0, 1 --> zext (!Cond)
10572   if (C1->isZero() && C2->isOne()) {
10573     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10574     NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
10575     return NotCond;
10576   }
10577 
10578   // select Cond, 0, -1 --> sext (!Cond)
10579   if (C1->isZero() && C2->isAllOnes()) {
10580     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10581     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
10582     return NotCond;
10583   }
10584 
10585   // Use a target hook because some targets may prefer to transform in the
10586   // other direction.
10587   if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
10588     return SDValue();
10589 
10590   // For any constants that differ by 1, we can transform the select into
10591   // an extend and add.
10592   const APInt &C1Val = C1->getAPIntValue();
10593   const APInt &C2Val = C2->getAPIntValue();
10594 
10595   // select Cond, C1, C1-1 --> add (zext Cond), C1-1
10596   if (C1Val - 1 == C2Val) {
10597     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
10598     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
10599   }
10600 
10601   // select Cond, C1, C1+1 --> add (sext Cond), C1+1
10602   if (C1Val + 1 == C2Val) {
10603     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
10604     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
10605   }
10606 
10607   // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
10608   if (C1Val.isPowerOf2() && C2Val.isZero()) {
10609     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
10610     SDValue ShAmtC =
10611         DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
10612     return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
10613   }
10614 
10615   // select Cond, -1, C --> or (sext Cond), C
10616   if (C1->isAllOnes()) {
10617     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
10618     return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
10619   }
10620 
10621   // select Cond, C, -1 --> or (sext (not Cond)), C
10622   if (C2->isAllOnes()) {
10623     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10624     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
10625     return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
10626   }
10627 
10628   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
10629     return V;
10630 
10631   return SDValue();
10632 }
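// Worked example (illustrative, editorial addition):
// (select i1 %c, i32 5, i32 4) satisfies C1 - 1 == C2, so it becomes
// (add (zext %c), 4): %c == 1 yields 5 and %c == 0 yields 4.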
10633 
10634 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
10635   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
10636          "Expected a (v)select");
10637   SDValue Cond = N->getOperand(0);
10638   SDValue T = N->getOperand(1), F = N->getOperand(2);
10639   EVT VT = N->getValueType(0);
10640   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
10641     return SDValue();
10642 
10643   // select Cond, Cond, F --> or Cond, F
10644   // select Cond, 1, F    --> or Cond, F
10645   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
10646     return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
10647 
10648   // select Cond, T, Cond --> and Cond, T
10649   // select Cond, T, 0    --> and Cond, T
10650   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
10651     return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
10652 
10653   // select Cond, T, 1 --> or (not Cond), T
10654   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
10655     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10656     return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
10657   }
10658 
10659   // select Cond, 0, F --> and (not Cond), F
10660   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
10661     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10662     return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
10663   }
10664 
10665   return SDValue();
10666 }
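// Worked example (illustrative, editorial addition): for i1 values,
// (select %c, %t, false) is (and %c, %t): when %c is 0 both forms produce 0
// (the false arm), and when %c is 1 both produce %t. The other folds above
// follow the same truth-table argument.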
10667 
10668 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
10669   SDValue N0 = N->getOperand(0);
10670   SDValue N1 = N->getOperand(1);
10671   SDValue N2 = N->getOperand(2);
10672   EVT VT = N->getValueType(0);
10673   if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
10674     return SDValue();
10675 
10676   SDValue Cond0 = N0.getOperand(0);
10677   SDValue Cond1 = N0.getOperand(1);
10678   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10679   if (VT != Cond0.getValueType())
10680     return SDValue();
10681 
10682   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
10683   // compare is inverted from that pattern ("Cond0 s> -1").
10684   if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
10685     ; // This is the pattern we are looking for.
10686   else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
10687     std::swap(N1, N2);
10688   else
10689     return SDValue();
10690 
10691   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
10692   if (isNullOrNullSplat(N2)) {
10693     SDLoc DL(N);
10694     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10695     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10696     return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
10697   }
10698 
10699   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
10700   if (isAllOnesOrAllOnesSplat(N1)) {
10701     SDLoc DL(N);
10702     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10703     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10704     return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
10705   }
10706 
10707   // If we have to invert the sign bit mask, only do that transform if the
10708   // target has a bitwise 'and not' instruction (the invert is free).
10709   // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
10710   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10711   if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
10712     SDLoc DL(N);
10713     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10714     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10715     SDValue Not = DAG.getNOT(DL, Sra, VT);
10716     return DAG.getNode(ISD::AND, DL, VT, Not, N2);
10717   }
10718 
10719   // TODO: There's another pattern in this family, but it may require
10720   //       implementing hasOrNot() to check for profitability:
10721   //       (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
10722 
10723   return SDValue();
10724 }
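// Worked example (illustrative, editorial addition): for <4 x i32> and
// (Cond0 s< 0) ? N1 : 0, a lane with Cond0 == -7 gives (Cond0 s>> 31) == -1
// and -1 & N1 == N1, while a lane with Cond0 == 7 gives 0 & N1 == 0,
// reproducing the vselect lane by lane with a shift and a mask.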
10725 
10726 SDValue DAGCombiner::visitSELECT(SDNode *N) {
10727   SDValue N0 = N->getOperand(0);
10728   SDValue N1 = N->getOperand(1);
10729   SDValue N2 = N->getOperand(2);
10730   EVT VT = N->getValueType(0);
10731   EVT VT0 = N0.getValueType();
10732   SDLoc DL(N);
10733   SDNodeFlags Flags = N->getFlags();
10734 
10735   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
10736     return V;
10737 
10738   if (SDValue V = foldBoolSelectToLogic(N, DAG))
10739     return V;
10740 
10741   // select (not Cond), N1, N2 -> select Cond, N2, N1
10742   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
10743     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
10744     SelectOp->setFlags(Flags);
10745     return SelectOp;
10746   }
10747 
10748   if (SDValue V = foldSelectOfConstants(N))
10749     return V;
10750 
10751   // If we can fold this based on the true/false value, do so.
10752   if (SimplifySelectOps(N, N1, N2))
10753     return SDValue(N, 0); // Don't revisit N.
10754 
10755   if (VT0 == MVT::i1) {
10756     // The code in this block deals with the following 2 equivalences:
10757     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
10758     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
10759     // The target can specify its preferred form with the
10760     // shouldNormalizeToSelectSequence() callback. However, we always transform
10761     // to the right-hand form if the inner select already exists in the DAG,
10762     // and we always transform to the left-hand form if we know that we can
10763     // further optimize the combination of the conditions.
10764     bool normalizeToSequence =
10765         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
10766     // select (and Cond0, Cond1), X, Y
10767     //   -> select Cond0, (select Cond1, X, Y), Y
10768     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
10769       SDValue Cond0 = N0->getOperand(0);
10770       SDValue Cond1 = N0->getOperand(1);
10771       SDValue InnerSelect =
10772           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
10773       if (normalizeToSequence || !InnerSelect.use_empty())
10774         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
10775                            InnerSelect, N2, Flags);
10776       // Cleanup on failure.
10777       if (InnerSelect.use_empty())
10778         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
10779     }
10780     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
10781     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
10782       SDValue Cond0 = N0->getOperand(0);
10783       SDValue Cond1 = N0->getOperand(1);
10784       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
10785                                         Cond1, N1, N2, Flags);
10786       if (normalizeToSequence || !InnerSelect.use_empty())
10787         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
10788                            InnerSelect, Flags);
10789       // Cleanup on failure.
10790       if (InnerSelect.use_empty())
10791         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
10792     }
10793 
10794     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
10795     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
10796       SDValue N1_0 = N1->getOperand(0);
10797       SDValue N1_1 = N1->getOperand(1);
10798       SDValue N1_2 = N1->getOperand(2);
10799       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
10800         // Create the actual and node if we can generate good code for it.
10801         if (!normalizeToSequence) {
10802           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
10803           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
10804                              N2, Flags);
10805         }
10806         // Otherwise see if we can optimize the "and" to a better pattern.
10807         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
10808           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
10809                              N2, Flags);
10810         }
10811       }
10812     }
10813     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
10814     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
10815       SDValue N2_0 = N2->getOperand(0);
10816       SDValue N2_1 = N2->getOperand(1);
10817       SDValue N2_2 = N2->getOperand(2);
10818       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
10819         // Create the actual or node if we can generate good code for it.
10820         if (!normalizeToSequence) {
10821           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
10822           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
10823                              N2_2, Flags);
10824         }
10825         // Otherwise see if we can optimize to a better pattern.
10826         if (SDValue Combined = visitORLike(N0, N2_0, N))
10827           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
10828                              N2_2, Flags);
10829       }
10830     }
10831   }
10832 
10833   // Fold selects based on a setcc into other things, such as min/max/abs.
10834   if (N0.getOpcode() == ISD::SETCC) {
10835     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
10836     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10837 
10838     // select (fcmp lt x, y), x, y -> fminnum x, y
10839     // select (fcmp gt x, y), x, y -> fmaxnum x, y
10840     //
10841     // This is OK if we don't care what happens if either operand is a NaN.
10842     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
10843       if (SDValue FMinMax =
10844               combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
10845         return FMinMax;
10846 
10847     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
10848     // This is conservatively limited to pre-legal-operations to give targets
10849     // a chance to reverse the transform if they want to do that. Also, it is
10850     // unlikely that the pattern would be formed late, so it's probably not
10851     // worth going through the other checks.
10852     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
10853         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
10854         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
10855       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
10856       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
10857       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
10858         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
10859         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
10860         //
10861         // The IR equivalent of this transform would have this form:
10862         //   %a = add %x, C
10863         //   %c = icmp ugt %x, ~C
10864         //   %r = select %c, -1, %a
10865         //   =>
10866         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
10867         //   %u0 = extractvalue %u, 0
10868         //   %u1 = extractvalue %u, 1
10869         //   %r = select %u1, -1, %u0
10870         SDVTList VTs = DAG.getVTList(VT, VT0);
10871         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
10872         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
10873       }
10874     }
10875 
10876     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
10877         (!LegalOperations &&
10878          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
10879       // Any flags available in a select/setcc fold will be on the setcc as they
10880       // migrated from fcmp.
10881       Flags = N0->getFlags();
10882       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
10883                                        N2, N0.getOperand(2));
10884       SelectNode->setFlags(Flags);
10885       return SelectNode;
10886     }
10887 
10888     if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
10889       return NewSel;
10890   }
10891 
10892   if (!VT.isVector())
10893     if (SDValue BinOp = foldSelectOfBinops(N))
10894       return BinOp;
10895 
10896   return SDValue();
10897 }
10898 
10899 // This function assumes all the vselect's arguments are CONCAT_VECTOR
10900 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
10901 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
10902   SDLoc DL(N);
10903   SDValue Cond = N->getOperand(0);
10904   SDValue LHS = N->getOperand(1);
10905   SDValue RHS = N->getOperand(2);
10906   EVT VT = N->getValueType(0);
10907   int NumElems = VT.getVectorNumElements();
10908   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
10909          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
10910          Cond.getOpcode() == ISD::BUILD_VECTOR);
10911 
10912   // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
10913   // about binary ones here.
10914   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
10915     return SDValue();
10916 
10917   // We're sure we have an even number of elements due to the
10918   // concat_vectors we have as arguments to vselect.
10919   // Walk the first half of the BV, skipping UNDEF elements; record the first
10920   // non-undef element and keep looping until we reach half the length of the
10921   // BV, checking that all the non-undef elements are the same.
10922   ConstantSDNode *BottomHalf = nullptr;
10923   for (int i = 0; i < NumElems / 2; ++i) {
10924     if (Cond->getOperand(i)->isUndef())
10925       continue;
10926 
10927     if (BottomHalf == nullptr)
10928       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10929     else if (Cond->getOperand(i).getNode() != BottomHalf)
10930       return SDValue();
10931   }
10932 
10933   // Do the same for the second half of the BuildVector
10934   ConstantSDNode *TopHalf = nullptr;
10935   for (int i = NumElems / 2; i < NumElems; ++i) {
10936     if (Cond->getOperand(i)->isUndef())
10937       continue;
10938 
10939     if (TopHalf == nullptr)
10940       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10941     else if (Cond->getOperand(i).getNode() != TopHalf)
10942       return SDValue();
10943   }
10944 
10945   assert(TopHalf && BottomHalf &&
10946          "One half of the selector was all UNDEFs and the other was all the "
10947          "same value. This should have been addressed before this function.");
10948   return DAG.getNode(
10949       ISD::CONCAT_VECTORS, DL, VT,
10950       BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
10951       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
10952 }
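// Worked example (illustrative, editorial addition): with
// Cond = <1, 1, 0, 0>, LHS = concat(A, B) and RHS = concat(C, D), the bottom
// half selects from LHS and the top half from RHS, so the vselect folds to
// concat(A, D) with no per-element select at all.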
10953 
10954 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
10955                        SelectionDAG &DAG, const SDLoc &DL) {
10956   if (Index.getOpcode() != ISD::ADD)
10957     return false;
10958 
10959   // Only perform the transformation when existing operands can be reused.
10960   if (IndexIsScaled)
10961     return false;
10962 
10963   if (!isNullConstant(BasePtr) && !Index.hasOneUse())
10964     return false;
10965 
10966   EVT VT = BasePtr.getValueType();
10967   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
10968       SplatVal && SplatVal.getValueType() == VT) {
10969     if (isNullConstant(BasePtr))
10970       BasePtr = SplatVal;
10971     else
10972       BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
10973     Index = Index.getOperand(1);
10974     return true;
10975   }
10976   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
10977       SplatVal && SplatVal.getValueType() == VT) {
10978     if (isNullConstant(BasePtr))
10979       BasePtr = SplatVal;
10980     else
10981       BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
10982     Index = Index.getOperand(0);
10983     return true;
10984   }
10985   return false;
10986 }
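// Worked example (illustrative, editorial addition): a gather/scatter with
// BasePtr = 0 and Index = (add (splat %p), %offsets) is rewritten to
// BasePtr = %p and Index = %offsets, recovering a uniform scalar base plus
// per-lane offsets, which is the form the indexed memory nodes expect.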
10987 
10988 // Fold sext/zext of index into index type.
10989 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
10990                      SelectionDAG &DAG) {
10991   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10992 
10993   // It's always safe to look through zero extends.
10994   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
10995     SDValue Op = Index.getOperand(0);
10996     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
10997       IndexType = ISD::UNSIGNED_SCALED;
10998       Index = Op;
10999       return true;
11000     }
11001     if (ISD::isIndexTypeSigned(IndexType)) {
11002       IndexType = ISD::UNSIGNED_SCALED;
11003       return true;
11004     }
11005   }
11006 
11007   // It's only safe to look through sign extends when Index is signed.
11008   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11009       ISD::isIndexTypeSigned(IndexType)) {
11010     SDValue Op = Index.getOperand(0);
11011     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
11012       Index = Op;
11013       return true;
11014     }
11015   }
11016 
11017   return false;
11018 }
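// Worked example (illustrative, editorial addition): an index of
// (zero_extend <4 x i32> %idx to <4 x i64>) can be replaced by %idx itself,
// with IndexType switched to the unsigned form, when the target reports the
// extend as removable, saving the explicit widening before the memory op.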
11019 
11020 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11021   VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
11022   SDValue Mask = MSC->getMask();
11023   SDValue Chain = MSC->getChain();
11024   SDValue Index = MSC->getIndex();
11025   SDValue Scale = MSC->getScale();
11026   SDValue StoreVal = MSC->getValue();
11027   SDValue BasePtr = MSC->getBasePtr();
11028   SDValue VL = MSC->getVectorLength();
11029   ISD::MemIndexType IndexType = MSC->getIndexType();
11030   SDLoc DL(N);
11031 
11032   // Zap scatters with a zero mask.
11033   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11034     return Chain;
11035 
11036   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11037     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11038     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11039                             DL, Ops, MSC->getMemOperand(), IndexType);
11040   }
11041 
11042   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11043     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11044     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11045                             DL, Ops, MSC->getMemOperand(), IndexType);
11046   }
11047 
11048   return SDValue();
11049 }
11050 
11051 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11052   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
11053   SDValue Mask = MSC->getMask();
11054   SDValue Chain = MSC->getChain();
11055   SDValue Index = MSC->getIndex();
11056   SDValue Scale = MSC->getScale();
11057   SDValue StoreVal = MSC->getValue();
11058   SDValue BasePtr = MSC->getBasePtr();
11059   ISD::MemIndexType IndexType = MSC->getIndexType();
11060   SDLoc DL(N);
11061 
11062   // Zap scatters with a zero mask.
11063   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11064     return Chain;
11065 
11066   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11067     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11068     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11069                                 DL, Ops, MSC->getMemOperand(), IndexType,
11070                                 MSC->isTruncatingStore());
11071   }
11072 
11073   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11074     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11075     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11076                                 DL, Ops, MSC->getMemOperand(), IndexType,
11077                                 MSC->isTruncatingStore());
11078   }
11079 
11080   return SDValue();
11081 }
11082 
11083 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11084   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
11085   SDValue Mask = MST->getMask();
11086   SDValue Chain = MST->getChain();
11087   SDValue Value = MST->getValue();
11088   SDValue Ptr = MST->getBasePtr();
11089   SDLoc DL(N);
11090 
11091   // Zap masked stores with a zero mask.
11092   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11093     return Chain;
11094 
11095   // If this is a masked store with an all-ones mask, we can use an unmasked store.
11096   // FIXME: Can we do this for indexed, compressing, or truncating stores?
11097   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
11098       !MST->isCompressingStore() && !MST->isTruncatingStore())
11099     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
11100                         MST->getBasePtr(), MST->getPointerInfo(),
11101                         MST->getOriginalAlign(), MachineMemOperand::MOStore,
11102                         MST->getAAInfo());
11103 
11104   // Try transforming N to an indexed store.
11105   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11106     return SDValue(N, 0);
11107 
11108   if (MST->isTruncatingStore() && MST->isUnindexed() &&
11109       Value.getValueType().isInteger() &&
11110       (!isa<ConstantSDNode>(Value) ||
11111        !cast<ConstantSDNode>(Value)->isOpaque())) {
11112     APInt TruncDemandedBits =
11113         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
11114                              MST->getMemoryVT().getScalarSizeInBits());
11115 
11116     // See if we can simplify the operation with
11117     // SimplifyDemandedBits, which only works if the value has a single use.
11118     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
11119       // Re-visit the store if anything changed and the store hasn't been merged
11120       // with another node (N is deleted). SimplifyDemandedBits will add Value's
11121       // node back to the worklist if necessary, but we also need to re-visit
11122       // the Store node itself.
11123       if (N->getOpcode() != ISD::DELETED_NODE)
11124         AddToWorklist(N);
11125       return SDValue(N, 0);
11126     }
11127   }
11128 
11129   // If this is a TRUNC followed by a masked store, fold this into a masked
11130   // truncating store.  We can do this even if this is already a masked
11131   // truncstore.
11132   // TODO: Try combining to a masked compress store if possible.
11133   if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
11134       MST->isUnindexed() && !MST->isCompressingStore() &&
11135       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
11136                                MST->getMemoryVT(), LegalOperations)) {
11137     auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
11138                                          Value.getOperand(0).getValueType());
11139     return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
11140                               MST->getOffset(), Mask, MST->getMemoryVT(),
11141                               MST->getMemOperand(), MST->getAddressingMode(),
11142                               /*IsTruncating=*/true);
11143   }
11144 
11145   return SDValue();
11146 }
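// Worked example (illustrative, editorial addition): a masked store of
// (truncate <8 x i32> %v to <8 x i16>), on a target that supports an i16
// masked truncstore, becomes a single truncating masked store of %v,
// removing the separate TRUNCATE node from the DAG.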
11147 
11148 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
11149   VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
11150   SDValue Mask = MGT->getMask();
11151   SDValue Chain = MGT->getChain();
11152   SDValue Index = MGT->getIndex();
11153   SDValue Scale = MGT->getScale();
11154   SDValue BasePtr = MGT->getBasePtr();
11155   SDValue VL = MGT->getVectorLength();
11156   ISD::MemIndexType IndexType = MGT->getIndexType();
11157   SDLoc DL(N);
11158 
11159   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
11160     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11161     return DAG.getGatherVP(
11162         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11163         Ops, MGT->getMemOperand(), IndexType);
11164   }
11165 
11166   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
11167     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11168     return DAG.getGatherVP(
11169         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11170         Ops, MGT->getMemOperand(), IndexType);
11171   }
11172 
11173   return SDValue();
11174 }
11175 
11176 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
11177   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
11178   SDValue Mask = MGT->getMask();
11179   SDValue Chain = MGT->getChain();
11180   SDValue Index = MGT->getIndex();
11181   SDValue Scale = MGT->getScale();
11182   SDValue PassThru = MGT->getPassThru();
11183   SDValue BasePtr = MGT->getBasePtr();
11184   ISD::MemIndexType IndexType = MGT->getIndexType();
11185   SDLoc DL(N);
11186 
11187   // Zap gathers with a zero mask.
11188   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11189     return CombineTo(N, PassThru, MGT->getChain());
11190 
11191   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
11192     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11193     return DAG.getMaskedGather(
11194         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11195         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11196   }
11197 
11198   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
11199     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11200     return DAG.getMaskedGather(
11201         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11202         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11203   }
11204 
11205   return SDValue();
11206 }
11207 
11208 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
11209   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
11210   SDValue Mask = MLD->getMask();
11211   SDLoc DL(N);
11212 
11213   // Zap masked loads with a zero mask.
11214   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11215     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
11216 
11217   // If this is a masked load with an all-ones mask, we can use an unmasked load.
11218   // FIXME: Can we do this for indexed, expanding, or extending loads?
11219   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
11220       !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
11221     SDValue NewLd = DAG.getLoad(
11222         N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
11223         MLD->getPointerInfo(), MLD->getOriginalAlign(),
11224         MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
11225     return CombineTo(N, NewLd, NewLd.getValue(1));
11226   }
11227 
11228   // Try transforming N to an indexed load.
11229   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11230     return SDValue(N, 0);
11231 
11232   return SDValue();
11233 }
11234 
11235 /// A vector select of 2 constant vectors can be simplified to math/logic to
11236 /// avoid a variable select instruction and possibly avoid constant loads.
11237 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
11238   SDValue Cond = N->getOperand(0);
11239   SDValue N1 = N->getOperand(1);
11240   SDValue N2 = N->getOperand(2);
11241   EVT VT = N->getValueType(0);
11242   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
11243       !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
11244       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
11245       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
11246     return SDValue();
11247 
11248   // Check if we can use the condition value to increment/decrement a single
11249   // constant value. This simplifies a select to an add and removes a constant
11250   // load/materialization from the general case.
11251   bool AllAddOne = true;
11252   bool AllSubOne = true;
11253   unsigned Elts = VT.getVectorNumElements();
11254   for (unsigned i = 0; i != Elts; ++i) {
11255     SDValue N1Elt = N1.getOperand(i);
11256     SDValue N2Elt = N2.getOperand(i);
11257     if (N1Elt.isUndef() || N2Elt.isUndef())
11258       continue;
11259     if (N1Elt.getValueType() != N2Elt.getValueType())
11260       continue;
11261 
11262     const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
11263     const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
11264     if (C1 != C2 + 1)
11265       AllAddOne = false;
11266     if (C1 != C2 - 1)
11267       AllSubOne = false;
11268   }
11269 
11270   // Further simplifications for the extra-special cases where the constants are
11271   // all 0 or all -1 should be implemented as folds of these patterns.
11272   SDLoc DL(N);
11273   if (AllAddOne || AllSubOne) {
11274     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
11275     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
11276     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
11277     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
11278     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
11279   }
11280 
11281   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
11282   APInt Pow2C;
11283   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
11284       isNullOrNullSplat(N2)) {
11285     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
11286     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
11287     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
11288   }
11289 
11290   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
11291     return V;
11292 
11293   // The general case for select-of-constants:
11294   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
11295   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
11296   // leave that to a machine-specific pass.
11297   return SDValue();
11298 }
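// Worked example (illustrative, editorial addition):
// (vselect <2 x i1> %c, <i32 3, i32 3>, <i32 2, i32 2>) has C1 == C2 + 1 in
// every lane, so it becomes (add (zero_extend %c), <i32 2, i32 2>), trading
// a variable vector select for an extend and an add.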
11299 
11300 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
11301   SDValue N0 = N->getOperand(0);
11302   SDValue N1 = N->getOperand(1);
11303   SDValue N2 = N->getOperand(2);
11304   EVT VT = N->getValueType(0);
11305   SDLoc DL(N);
11306 
11307   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
11308     return V;
11309 
11310   if (SDValue V = foldBoolSelectToLogic(N, DAG))
11311     return V;
11312 
11313   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
11314   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
11315     return DAG.getSelect(DL, VT, F, N2, N1);
11316 
11317   // Canonicalize integer abs.
11318   // vselect (setg[te] X,  0),  X, -X ->
11319   // vselect (setgt    X, -1),  X, -X ->
11320   // vselect (setl[te] X,  0), -X,  X ->
11321   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
11322   if (N0.getOpcode() == ISD::SETCC) {
11323     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
11324     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11325     bool isAbs = false;
11326     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
11327 
11328     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
11329          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
11330         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
11331       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
11332     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
11333              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
11334       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
11335 
11336     if (isAbs) {
11337       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
11338         return DAG.getNode(ISD::ABS, DL, VT, LHS);
11339 
11340       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
11341                                   DAG.getConstant(VT.getScalarSizeInBits() - 1,
11342                                                   DL, getShiftAmountTy(VT)));
11343       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
11344       AddToWorklist(Shift.getNode());
11345       AddToWorklist(Add.getNode());
11346       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
11347     }
11348 
11349     // vselect x, y (fcmp lt x, y) -> fminnum x, y
11350     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
11351     //
11352     // This is OK if we don't care about what happens if either operand is a
11353     // NaN.
11354     //
11355     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
11356       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
11357         return FMinMax;
11358     }
11359 
11360     if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
11361       return S;
11362     if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
11363       return S;
11364 
11365     // If this select has a condition (setcc) with narrower operands than the
11366     // select, try to widen the compare to match the select width.
11367     // TODO: This should be extended to handle any constant.
11368     // TODO: This could be extended to handle non-loading patterns, but that
11369     //       requires thorough testing to avoid regressions.
11370     if (isNullOrNullSplat(RHS)) {
11371       EVT NarrowVT = LHS.getValueType();
11372       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
11373       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
11374       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
11375       unsigned WideWidth = WideVT.getScalarSizeInBits();
11376       bool IsSigned = isSignedIntSetCC(CC);
11377       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
11378       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
11379           SetCCWidth != 1 && SetCCWidth < WideWidth &&
11380           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
11381           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
11382         // Both compare operands can be widened for free. The LHS can use an
11383         // extended load, and the RHS is a constant:
11384         //   vselect (ext (setcc load(X), C)), N1, N2 -->
11385         //   vselect (setcc extload(X), C'), N1, N2
11386         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
11387         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
11388         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
11389         EVT WideSetCCVT = getSetCCResultType(WideVT);
11390         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
11391         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
11392       }
11393     }
11394 
11395     // Match VSELECTs into add with unsigned saturation.
11396     if (hasOperation(ISD::UADDSAT, VT)) {
11397       // Check if one of the arms of the VSELECT is a vector with all bits set.
11398       // If it's on the left side invert the predicate to simplify logic below.
11399       SDValue Other;
11400       ISD::CondCode SatCC = CC;
11401       if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
11402         Other = N2;
11403         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
11404       } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
11405         Other = N1;
11406       }
11407 
11408       if (Other && Other.getOpcode() == ISD::ADD) {
11409         SDValue CondLHS = LHS, CondRHS = RHS;
11410         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
11411 
11412         // Canonicalize condition operands.
11413         if (SatCC == ISD::SETUGE) {
11414           std::swap(CondLHS, CondRHS);
11415           SatCC = ISD::SETULE;
11416         }
11417 
11418         // We can test against either of the addition operands.
11419         // x <= x+y ? x+y : ~0 --> uaddsat x, y
11420         // x+y >= x ? x+y : ~0 --> uaddsat x, y
11421         if (SatCC == ISD::SETULE && Other == CondRHS &&
11422             (OpLHS == CondLHS || OpRHS == CondLHS))
11423           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
11424 
11425         if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
11426             (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
11427              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
11428             CondLHS == OpLHS) {
11429           // If the RHS is a constant we have to reverse the const
11430           // canonicalization.
11431           // x >= ~C ? x+C : ~0 --> uaddsat x, C
11432           auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
11433             return Cond->getAPIntValue() == ~Op->getAPIntValue();
11434           };
11435           if (SatCC == ISD::SETULE &&
11436               ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
11437             return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
11438         }
11439       }
11440     }
11441 
11442     // Match VSELECTs into sub with unsigned saturation.
11443     if (hasOperation(ISD::USUBSAT, VT)) {
11444       // Check if one of the arms of the VSELECT is a zero vector. If it's on
11445       // the left side invert the predicate to simplify logic below.
11446       SDValue Other;
11447       ISD::CondCode SatCC = CC;
11448       if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
11449         Other = N2;
11450         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
11451       } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
11452         Other = N1;
11453       }
11454 
11455       // zext(x) >= y ? trunc(zext(x) - y) : 0
11456       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
11457       // zext(x) >  y ? trunc(zext(x) - y) : 0
11458       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
11459       if (Other && Other.getOpcode() == ISD::TRUNCATE &&
11460           Other.getOperand(0).getOpcode() == ISD::SUB &&
11461           (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
11462         SDValue OpLHS = Other.getOperand(0).getOperand(0);
11463         SDValue OpRHS = Other.getOperand(0).getOperand(1);
11464         if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
11465           if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
11466                                               DAG, DL))
11467             return R;
11468       }
11469 
11470       if (Other && Other.getNumOperands() == 2) {
11471         SDValue CondRHS = RHS;
11472         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
11473 
11474         if (OpLHS == LHS) {
11475           // Look for a general sub with unsigned saturation first.
11476           // x >= y ? x-y : 0 --> usubsat x, y
11477           // x >  y ? x-y : 0 --> usubsat x, y
11478           if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
11479               Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
11480             return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11481 
11482           if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
11483               OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
11484             if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
11485                 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
11486               // If the RHS is a constant we have to reverse the const
11487               // canonicalization.
11488               // x > C-1 ? x+-C : 0 --> usubsat x, C
11489               auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
11490                 return (!Op && !Cond) ||
11491                        (Op && Cond &&
11492                         Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
11493               };
11494               if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
11495                   ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
11496                                             /*AllowUndefs*/ true)) {
11497                 OpRHS = DAG.getNegative(OpRHS, DL, VT);
11498                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11499               }
11500 
11501               // Another special case: If C was a sign bit, the sub has been
11502               // canonicalized into a xor.
11503               // FIXME: Would it be better to use computeKnownBits to
11504               // determine whether it's safe to decanonicalize the xor?
11505               // x s< 0 ? x^C : 0 --> usubsat x, C
11506               APInt SplatValue;
11507               if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
11508                   ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
11509                   ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
11510                   SplatValue.isSignMask()) {
11511                 // Note that we have to rebuild the RHS constant here to
11512                 // ensure we don't rely on particular values of undef lanes.
11513                 OpRHS = DAG.getConstant(SplatValue, DL, VT);
11514                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11515               }
11516             }
11517           }
11518         }
11519       }
11520     }
11521   }
11522 
11523   if (SimplifySelectOps(N, N1, N2))
11524     return SDValue(N, 0);  // Don't revisit N.
11525 
11526   // Fold (vselect all_ones, N1, N2) -> N1
11527   if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
11528     return N1;
11529   // Fold (vselect all_zeros, N1, N2) -> N2
11530   if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
11531     return N2;
11532 
11533   // The ConvertSelectToConcatVector function assumes both of the above
11534   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
11535   // and addressed.
11536   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
11537       N2.getOpcode() == ISD::CONCAT_VECTORS &&
11538       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
11539     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
11540       return CV;
11541   }
11542 
11543   if (SDValue V = foldVSelectOfConstants(N))
11544     return V;
11545 
11546   if (hasOperation(ISD::SRA, VT))
11547     if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
11548       return V;
11549 
11550   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
11551     return SDValue(N, 0);
11552 
11553   return SDValue();
11554 }
11555 
SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  SDValue N3 = N->getOperand(3);
  SDValue N4 = N->getOperand(4);
  ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();

  // fold select_cc lhs, rhs, x, x, cc -> x
  if (N2 == N3)
    return N2;

  // select_cc bool, 0, x, y, seteq -> select bool, y, x
  if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
      isNullConstant(N1))
    return DAG.getSelect(SDLoc(N), N2.getValueType(), N0, N3, N2);

  // Determine if the condition we're dealing with is constant
  if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
                                  CC, SDLoc(N), false)) {
    AddToWorklist(SCC.getNode());

    // cond always true -> true val
    // cond always false -> false val
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
      return SCCC->isZero() ? N3 : N2;

    // When the condition is UNDEF, just return the first operand. This is
    // coherent with DAG creation, where no setcc node is created in this case.
    if (SCC->isUndef())
      return N2;

    // Fold to a simpler select_cc
    if (SCC.getOpcode() == ISD::SETCC) {
      SDValue SelectOp = DAG.getNode(
          ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
          SCC.getOperand(1), N2, N3, SCC.getOperand(2));
      SelectOp->setFlags(SCC->getFlags());
      return SelectOp;
    }
  }

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N2, N3))
    return SDValue(N, 0);  // Don't revisit N.

  // fold select_cc into other things, such as min/max/abs
  return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
}

SDValue DAGCombiner::visitSETCC(SDNode *N) {
  // setcc is very commonly used as an argument to brcond. This pattern
  // also lends itself to numerous combines and, as a result, it is desirable
  // to keep the argument to a brcond as a setcc as much as possible.
  bool PreferSetCC =
      N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;

  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
  EVT VT = N->getValueType(0);

  //   SETCC(FREEZE(X), CONST, Cond)
  // =>
  //   FREEZE(SETCC(X, CONST, Cond))
  // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
  // isn't equivalent to true or false.
  // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
  // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
  //
  // This transformation is beneficial because visitBRCOND can fold
  // BRCOND(FREEZE(X)) to BRCOND(X).

  // Conservatively optimize integer comparisons only.
  if (PreferSetCC) {
    // Do this only when SETCC is going to be used by BRCOND.

    SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
    ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
    bool Updated = false;

    // Is 'X Cond C' always true or false?
    auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
      bool False = (Cond == ISD::SETULT && C->isZero()) ||
                   (Cond == ISD::SETLT  && C->isMinSignedValue()) ||
                   (Cond == ISD::SETUGT && C->isAllOnes()) ||
                   (Cond == ISD::SETGT  && C->isMaxSignedValue());
      bool True =  (Cond == ISD::SETULE && C->isAllOnes()) ||
                   (Cond == ISD::SETLE  && C->isMaxSignedValue()) ||
                   (Cond == ISD::SETUGE && C->isZero()) ||
                   (Cond == ISD::SETGE  && C->isMinSignedValue());
      return True || False;
    };
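    // For example, (setult X, 0) is always false and (setuge X, 0) is always
    // true. In such cases SETCC(FREEZE(X), C) is a known constant even when X
    // is poison, while FREEZE(SETCC(X, C)) could be either value, so the
    // freeze must stay where it is.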
    if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) {
      if (!IsAlwaysTrueOrFalse(Cond, N1C)) {
        N0 = N0->getOperand(0);
        Updated = true;
      }
    }
    if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) {
      if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond),
                               N0C)) {
        N1 = N1->getOperand(0);
        Updated = true;
      }
    }

    if (Updated)
      return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond));
  }

  SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
                                   SDLoc(N), !PreferSetCC);

  if (!Combined)
    return SDValue();

  // If we prefer to have a setcc, and we don't, we'll try our best to
  // recreate one using rebuildSetCC.
  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
    SDValue NewSetCC = rebuildSetCC(Combined);

    // We don't have anything interesting to combine to.
    if (NewSetCC.getNode() == N)
      return SDValue();

    if (NewSetCC)
      return NewSetCC;
  }

  return Combined;
}

SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDValue Carry = N->getOperand(2);
  SDValue Cond = N->getOperand(3);

  // If Carry is false, fold to a regular SETCC.
  if (isNullConstant(Carry))
    return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);

  return SDValue();
}

/// Check if N satisfies:
///   N is used once.
///   N is a load.
///   The load is compatible with ExtOpcode: if the load has an explicit
///   zero/sign extension, ExtOpcode must have the same extension;
///   otherwise any extension opcode is compatible.
static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
  if (!N.hasOneUse())
    return false;

  if (!isa<LoadSDNode>(N))
    return false;

  LoadSDNode *Load = cast<LoadSDNode>(N);
  ISD::LoadExtType LoadExt = Load->getExtensionType();
  if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
    return true;

  // Now LoadExt is either SEXTLOAD or ZEXTLOAD, so ExtOpcode must have the
  // same extension.
  if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
      (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
    return false;

  return true;
}

/// Fold
///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
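/// For example (illustrative types):
///   (i32 (sext (select c, (i8 load x), (i8 load y))))
///     --> (select c, (i32 sextload x), (i32 sextload y))
/// provided the target reports the i8 -> i32 sextload as legal.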
static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
          Opcode == ISD::ANY_EXTEND) &&
         "Expected EXTEND dag node in input!");

  if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
      !N0.hasOneUse())
    return SDValue();

  SDValue Op1 = N0->getOperand(1);
  SDValue Op2 = N0->getOperand(2);
  if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
    return SDValue();

  auto ExtLoadOpcode = ISD::EXTLOAD;
  if (Opcode == ISD::SIGN_EXTEND)
    ExtLoadOpcode = ISD::SEXTLOAD;
  else if (Opcode == ISD::ZERO_EXTEND)
    ExtLoadOpcode = ISD::ZEXTLOAD;

  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
    return SDValue();

  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
}

/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
/// Vector extends are not folded if operations are legal; this is to
/// avoid introducing illegal build_vector dag nodes.
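/// For example (with the undef-lane handling implemented below):
///   (v4i32 (sext (v4i16 build_vector <-1, 0, 7, undef>)))
///     --> (v4i32 build_vector <-1, 0, 7, 0>)
/// where the undef lane is replaced by 0 so the result does not depend on
/// whatever value the undef lane might take.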
static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG, bool LegalTypes) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
          Opcode == ISD::ANY_EXTEND ||
          Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
          Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
          Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
         "Expected EXTEND dag node in input!");

  // fold (sext c1) -> c1
  // fold (zext c1) -> c1
  // fold (aext c1) -> c1
  if (isa<ConstantSDNode>(N0))
    return DAG.getNode(Opcode, DL, VT, N0);

  // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
  // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  if (N0->getOpcode() == ISD::SELECT) {
    SDValue Op1 = N0->getOperand(1);
    SDValue Op2 = N0->getOperand(2);
    if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
        (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg, i.e.
      //
      // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
      // t2: i64 = any_extend t1
      // -->
      // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
      // -->
      // t4: i64 = sign_extend_inreg t3
      unsigned FoldOpc = Opcode;
      if (FoldOpc == ISD::ANY_EXTEND)
        FoldOpc = ISD::SIGN_EXTEND;
      return DAG.getSelect(DL, VT, N0->getOperand(0),
                           DAG.getNode(FoldOpc, DL, VT, Op1),
                           DAG.getNode(FoldOpc, DL, VT, Op2));
    }
  }

  // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
  EVT SVT = VT.getScalarType();
  if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
    return SDValue();

  // We can fold this node into a build_vector.
  unsigned VTBits = SVT.getSizeInBits();
  unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
  SmallVector<SDValue, 8> Elts;
  unsigned NumElts = VT.getVectorNumElements();

  for (unsigned i = 0; i != NumElts; ++i) {
    SDValue Op = N0.getOperand(i);
    if (Op.isUndef()) {
      if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
        Elts.push_back(DAG.getUNDEF(SVT));
      else
        Elts.push_back(DAG.getConstant(0, DL, SVT));
      continue;
    }

    SDLoc DL(Op);
    // Get the constant value and if needed trunc it to the size of the type.
    // Nodes like build_vector might have constants wider than the scalar type.
    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
      Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
    else
      Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
  }

  return DAG.getBuildVector(VT, DL, Elts);
}

// ExtendUsesToFormExtLoad - Try to extend uses of a load to enable this:
// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
// transformation. Returns true if extensions are possible and the
// above-mentioned transformation is profitable.
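// For example, a second use like (setcc (load x), C, setlt) does not block
// the transformation: such users are collected in ExtendNodes, and
// ExtendSetCCUses can rewrite the comparison in the wider type by extending
// the constant operand C as well.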
static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
                                    unsigned ExtOpc,
                                    SmallVectorImpl<SDNode *> &ExtendNodes,
                                    const TargetLowering &TLI) {
  bool HasCopyToRegUses = false;
  bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
  for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
       ++UI) {
    SDNode *User = *UI;
    if (User == N)
      continue;
    if (UI.getUse().getResNo() != N0.getResNo())
      continue;
    // FIXME: Only extend SETCC N, N and SETCC N, c for now.
    if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
      ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
      if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
        // Sign bits will be lost after a zext.
        return false;
      bool Add = false;
      for (unsigned i = 0; i != 2; ++i) {
        SDValue UseOp = User->getOperand(i);
        if (UseOp == N0)
          continue;
        if (!isa<ConstantSDNode>(UseOp))
          return false;
        Add = true;
      }
      if (Add)
        ExtendNodes.push_back(User);
      continue;
    }
    // If truncates aren't free and there are users we can't
    // extend, it isn't worthwhile.
    if (!isTruncFree)
      return false;
    // Remember if this value is live-out.
    if (User->getOpcode() == ISD::CopyToReg)
      HasCopyToRegUses = true;
  }

  if (HasCopyToRegUses) {
    bool BothLiveOut = false;
    for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
        BothLiveOut = true;
        break;
      }
    }
    if (BothLiveOut)
      // Both unextended and extended values are live out. There had better be
      // a good reason for the transformation.
      return ExtendNodes.size();
  }
  return true;
}

void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                                  SDValue OrigLoad, SDValue ExtLoad,
                                  ISD::NodeType ExtType) {
  // Extend SetCC uses if necessary.
  SDLoc DL(ExtLoad);
  for (SDNode *SetCC : SetCCs) {
    SmallVector<SDValue, 4> Ops;

    for (unsigned j = 0; j != 2; ++j) {
      SDValue SOp = SetCC->getOperand(j);
      if (SOp == OrigLoad)
        Ops.push_back(ExtLoad);
      else
        Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
    }

    Ops.push_back(SetCC->getOperand(2));
    CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
  }
}

// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();

  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) &&
         "Unexpected node type (not an extend)!");

  // fold (sext (load x)) to multiple smaller sextloads; same for zext.
  // For example, on a target with legal v4i32, but illegal v8i32, turn:
  //   (v8i32 (sext (v8i16 (load x))))
  // into:
  //   (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                          (v4i32 (sextload (x + 16)))))
  // Where uses of the original load, i.e.:
  //   (v8i16 (load x))
  // are replaced with:
  //   (v8i16 (truncate
  //     (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                            (v4i32 (sextload (x + 16)))))))
  //
  // This combine is only applicable to illegal, but splittable, vectors.
  // All legal types, and illegal non-vector types, are handled elsewhere.
  // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
  //
  if (N0->getOpcode() != ISD::LOAD)
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);

  if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
      !N0.hasOneUse() || !LN0->isSimple() ||
      !DstVT.isVector() || !DstVT.isPow2VectorType() ||
      !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SmallVector<SDNode *, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
    return SDValue();

  ISD::LoadExtType ExtType =
      N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;

  // Try to split the vector types to get down to legal types.
  EVT SplitSrcVT = SrcVT;
  EVT SplitDstVT = DstVT;
  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
         SplitSrcVT.getVectorNumElements() > 1) {
    SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
    SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
  }

  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
    return SDValue();

  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");

  SDLoc DL(N);
  const unsigned NumSplits =
      DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
  const unsigned Stride = SplitSrcVT.getStoreSize();
  SmallVector<SDValue, 4> Loads;
  SmallVector<SDValue, 4> Chains;

  SDValue BasePtr = LN0->getBasePtr();
  for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
    const unsigned Offset = Idx * Stride;
    const Align Align = commonAlignment(LN0->getAlign(), Offset);

    SDValue SplitLoad = DAG.getExtLoad(
        ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
        LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());

    BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);

    Loads.push_back(SplitLoad.getValue(0));
    Chains.push_back(SplitLoad.getValue(1));
  }

  SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
  SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);

  // Simplify TF.
  AddToWorklist(NewChain.getNode());

  CombineTo(N, NewValue);

  // Replace uses of the original load (before extension)
  // with a truncate of the concatenated sextloaded vectors.
  SDValue Trunc =
      DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
  ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
  CombineTo(N0.getNode(), Trunc, NewChain);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
//      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
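// For example (illustrative):
//   (i32 (zext (and (srl (i16 load x), 4), 255)))
//     --> (and (srl (i32 zextload x), 4), 255)
// The zextload's upper bits are zero, so shifting and masking in the wider
// type produces the same value as zero-extending the narrow result.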
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor
  SDValue N0 = N->getOperand(0);
  if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
        N0.getOpcode() == ISD::XOR) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();

  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N,0); // Return N so it doesn't get rechecked!
}

/// If we're narrowing or widening the result of a vector select and the final
/// size is the same size as a setcc (compare) feeding the select, then try to
/// apply the cast operation to the select's operands because matching vector
/// sizes for a select condition and other operands should be more efficient.
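/// For example (illustrative types, before legalization):
///   (v4i32 (trunc (vselect (setcc v4i32 A, B), v4i64 X, v4i64 Y)))
///     --> (vselect (setcc v4i32 A, B), (v4i32 trunc X), (v4i32 trunc Y))
/// so the select condition and the selected values end up the same width.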
SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
  unsigned CastOpcode = Cast->getOpcode();
  assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
          CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
          CastOpcode == ISD::FP_ROUND) &&
         "Unexpected opcode for vector select narrowing/widening");

  // We only do this transform before legal ops because the pattern may be
  // obfuscated by target-specific operations after legalization. Do not create
  // an illegal select op, however, because that may be difficult to lower.
  EVT VT = Cast->getValueType(0);
  if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
    return SDValue();

  SDValue VSel = Cast->getOperand(0);
  if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
      VSel.getOperand(0).getOpcode() != ISD::SETCC)
    return SDValue();

  // Does the setcc have the same vector size as the casted select?
  SDValue SetCC = VSel.getOperand(0);
  EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
  if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
    return SDValue();

  // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
  SDValue A = VSel.getOperand(1);
  SDValue B = VSel.getOperand(2);
  SDValue CastA, CastB;
  SDLoc DL(Cast);
  if (CastOpcode == ISD::FP_ROUND) {
    // FP_ROUND (fptrunc) has an extra flag operand to pass along.
    CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
    CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
  } else {
    CastA = DAG.getNode(CastOpcode, DL, VT, A);
    CastB = DAG.getNode(CastOpcode, DL, VT, B);
  }
  return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
}

// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  SDNode *N0Node = N0.getNode();
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
// Only generate vector extloads when 1) they're legal, and 2) they are
// deemed desirable by the target.
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  // TODO: isFixedLengthVector() should be removed; any negative effects on
  // code generation would then be the responsibility of that target's
  // implementation of isVectorLoadExtDesirable().
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isFixedLengthVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
                                        const TargetLowering &TLI, EVT VT,
                                        SDNode *N, SDValue N0,
                                        ISD::LoadExtType ExtLoadType,
                                        ISD::NodeType ExtOpc) {
  if (!N0.hasOneUse())
    return SDValue();

  MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
  if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
    return SDValue();

  if (!TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
    return SDValue();

  if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  SDLoc dl(Ld);
  SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
  SDValue NewLoad = DAG.getMaskedLoad(
      VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
      PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
      ExtLoadType, Ld->isExpandingLoad());
  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
  return NewLoad;
}

static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
                                       bool LegalOperations) {
  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");

  SDValue SetCC = N->getOperand(0);
  if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
      !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
    return SDValue();

  SDValue X = SetCC.getOperand(0);
  SDValue Ones = SetCC.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT XVT = X.getValueType();
  // setge X, C is canonicalized to setgt, so we do not need to match that
  // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
  // not require the 'not' op.
  if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
    // Invert and smear/shift the sign bit:
    // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
    // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
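    // For example, for i32: (sext i1 (setgt i32 X, -1)) becomes
    // (sra (xor X, -1), 31): the inverted sign bit is smeared across the
    // result, giving all-ones when X is non-negative and zero otherwise.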
    SDLoc DL(N);
    unsigned ShCt = VT.getSizeInBits() - 1;
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
      SDValue NotX = DAG.getNOT(DL, X, VT);
      SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
      auto ShiftOpcode =
        N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
      return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
    }
  }
  return SDValue();
}

SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  if (N0.getOpcode() != ISD::SETCC)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT N00VT = N00.getValueType();
  SDLoc DL(N);

  // Propagate fast-math-flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

  // On some architectures (such as SSE/NEON/etc) the SETCC result type is
  // the same size as the compared operands. Try to optimize sext(setcc())
  // if this is the case.
  if (VT.isVector() && !LegalOperations &&
      TLI.getBooleanContents(N00VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    EVT SVT = getSetCCResultType(N00VT);

    // If we already have the desired type, don't change it.
    if (SVT != N0.getValueType()) {
      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == SVT.getSizeInBits())
        return DAG.getSetCC(DL, VT, N00, N01, CC);

      // If the desired elements are smaller or larger than the source
      // elements, we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
      if (SVT == MatchingVecType) {
        SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
        return DAG.getSExtOrTrunc(VsetCC, DL, VT);
      }
    }

    // Try to eliminate the sext of a setcc by zexting the compare operands.
    if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
        !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
      bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
      unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

      // We have an unsupported narrow vector compare op that would be legal
      // if extended to the destination type. See if the compare operands
      // can be freely extended to the destination type.
      auto IsFreeToExtend = [&](SDValue V) {
        if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
          return true;
        // Match a simple, non-extended load that can be converted to a
        // legal {z/s}ext-load.
        // TODO: Allow widening of an existing {z/s}ext-load?
        if (!(ISD::isNON_EXTLoad(V.getNode()) &&
              ISD::isUNINDEXEDLoad(V.getNode()) &&
              cast<LoadSDNode>(V)->isSimple() &&
              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
          return false;

        // Non-chain users of this value must either be the setcc in this
        // sequence or extends that can be folded into the new {z/s}ext-load.
        for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
             UI != UE; ++UI) {
          // Skip uses of the chain and the setcc.
          SDNode *User = *UI;
          if (UI.getUse().getResNo() != 0 || User == N0.getNode())
            continue;
          // Extra users must have exactly the same cast we are about to create.
          // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
          //       is enhanced similarly.
          if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
            return false;
        }
        return true;
      };

      if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
        SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
        SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
        return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
      }
    }
  }

  // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
  // Here, T can be 1 or -1, depending on the type of the setcc and
  // getBooleanContents().
  unsigned SetCCWidth = N0.getScalarValueSizeInBits();

  // To determine the "true" side of the select, we need to know the high bit
  // of the value returned by the setcc if it evaluates to true.
  // If the type of the setcc is i1, then the true case of the select is just
  // sext(i1 1), that is, -1.
  // If the type of the setcc is larger (say, i8) then the value of the high
  // bit depends on getBooleanContents(), so ask TLI for a real "true" value
  // of the appropriate width.
  SDValue ExtTrueVal = (SetCCWidth == 1)
                           ? DAG.getAllOnesConstant(DL, VT)
                           : DAG.getBoolConstant(true, DL, VT, N00VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
    return SCC;

  if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
    EVT SetCCVT = getSetCCResultType(N00VT);
    // Don't do this transform for i1 because there's a select transform
    // that would reverse it.
    // TODO: We should not do this transform at all without a target hook
    // because a sext is likely cheaper than a select?
    if (SetCCVT.getScalarSizeInBits() != 1 &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
      SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
      return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
    }
  }

  return SDValue();
}

SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
      return FoldedVOp;

  // sext(undef) = 0 because the top bits will all be the same.
  if (N0.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  // fold (sext (sext_inreg x)) -> (sext (trunc x))
  if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    SDValue N00 = N0.getOperand(0);
    EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
    if (N00.getOpcode() == ISD::TRUNCATE &&
        (!LegalTypes || TLI.isTypeLegal(ExtVT))) {
      SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00.getOperand(0));
      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
    }
  }

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended.  If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits   = Op.getScalarValueSizeInBits();
    unsigned MidBits  = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();
    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);

    if (OpBits == DestBits) {
      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
      // bits, it is already ready.
      if (NumSignBits > DestBits-MidBits)
        return Op;
    } else if (OpBits < DestBits) {
      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
      // bits, just sext from i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
    } else {
      // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
      // bits, just truncate to i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::SEXTLOAD, ISD::SIGN_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
                               ISD::SIGN_EXTEND))
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (sextload x), (sext cst))
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      SmallVector<SDNode*, 4> SetCCs;
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (SDValue V = foldSextSetcc(N))
    return V;

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isNullOrNullSplat(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNegative(Zext, DL, VT);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  // fold sext (not i1 X) -> add (zext i1 X), -1
  // TODO: This could be extended to handle bool vectors.
  if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
      (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
                            TLI.isOperationLegal(ISD::ADD, VT)))) {
    // If we can eliminate the 'not', the sext form should be better.
    if (SDValue NewXor = visitXOR(N0.getNode())) {
      // Returning N0 is a form of in-visit replacement that may have
      // invalidated N0.
      if (NewXor.getNode() == N0.getNode()) {
        // Return SDValue here as the xor should have already been replaced in
        // this sext.
        return SDValue();
      }

      // Return a new sext with the new xor.
      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
    }

    SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}

// isTruncateOf - If N is a truncate of some other value, return true and
// record the value being truncated in Op and which of Op's bits are zero/one
// in Known. This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
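// For example, (setne X, 0) where all bits of X other than bit 0 are known
// zero behaves exactly like a truncate of X to i1; the SETCC path below
// checks this via (Known.Zero | 1).isAllOnes().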
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
                         KnownBits &Known) {
  if (N->getOpcode() == ISD::TRUNCATE) {
    Op = N->getOperand(0);
    Known = DAG.computeKnownBits(Op);
    return true;
  }

  if (N.getOpcode() != ISD::SETCC ||
      N.getValueType().getScalarType() != MVT::i1 ||
      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
    return false;

  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  assert(Op0.getValueType() == Op1.getValueType());

  if (isNullOrNullSplat(Op0))
    Op = Op1;
  else if (isNullOrNullSplat(Op1))
    Op = Op0;
  else
    return false;

  Known = DAG.computeKnownBits(Op);

  return (Known.Zero | 1).isAllOnes();
}

/// Given an extending node with a pop-count operand, if the target does not
/// support a pop-count in the narrow source type but does support it in the
/// destination type, widen the pop-count to the destination type.
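/// For example, if i16 CTPOP is not legal but i32 CTPOP is:
///   (i32 (zext (i16 ctpop X))) --> (i32 ctpop (zext X to i32))
/// The zero-extended bits contribute no set bits, so both counts are equal.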
widenCtPop(SDNode * Extend,SelectionDAG & DAG)12677 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
12678   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
12679           Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
12680 
12681   SDValue CtPop = Extend->getOperand(0);
12682   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
12683     return SDValue();
12684 
12685   EVT VT = Extend->getValueType(0);
12686   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12687   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
12688       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
12689     return SDValue();
12690 
12691   // zext (ctpop X) --> ctpop (zext X)
12692   SDLoc DL(Extend);
12693   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
12694   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
12695 }
12696 
12697 // If we have (zext (abs X)) where X is a type that will be promoted by type
12698 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
widenAbs(SDNode * Extend,SelectionDAG & DAG)12699 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
12700   assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
12701 
12702   EVT VT = Extend->getValueType(0);
12703   if (VT.isVector())
12704     return SDValue();
12705 
12706   SDValue Abs = Extend->getOperand(0);
12707   if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
12708     return SDValue();
12709 
12710   EVT AbsVT = Abs.getValueType();
12711   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12712   if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
12713       TargetLowering::TypePromoteInteger)
12714     return SDValue();
12715 
12716   EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
12717 
12718   SDValue SExt =
12719       DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
12720   SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
12721   return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
12722 }
12723 
visitZERO_EXTEND(SDNode * N)12724 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
12725   SDValue N0 = N->getOperand(0);
12726   EVT VT = N->getValueType(0);
12727 
12728   if (VT.isVector())
12729     if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
12730       return FoldedVOp;
12731 
12732   // zext(undef) = 0
12733   if (N0.isUndef())
12734     return DAG.getConstant(0, SDLoc(N), VT);
12735 
12736   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
12737     return Res;
12738 
12739   // fold (zext (zext x)) -> (zext x)
12740   // fold (zext (aext x)) -> (zext x)
12741   if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
12742     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
12743                        N0.getOperand(0));
12744 
12745   // fold (zext (truncate x)) -> (zext x) or
12746   //      (zext (truncate x)) -> (truncate x)
12747   // This is valid when the truncated bits of x are already zero.
12748   SDValue Op;
12749   KnownBits Known;
12750   if (isTruncateOf(DAG, N0, Op, Known)) {
12751     APInt TruncatedBits =
12752       (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
12753       APInt(Op.getScalarValueSizeInBits(), 0) :
12754       APInt::getBitsSet(Op.getScalarValueSizeInBits(),
12755                         N0.getScalarValueSizeInBits(),
12756                         std::min(Op.getScalarValueSizeInBits(),
12757                                  VT.getScalarSizeInBits()));
12758     if (TruncatedBits.isSubsetOf(Known.Zero))
12759       return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
12760   }
12761 
12762   // fold (zext (truncate x)) -> (and x, mask)
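  // e.g. (i32 (zext (i8 (trunc i64:x)))) --> (i32 (and (i32 (trunc x)), 255))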
12763   if (N0.getOpcode() == ISD::TRUNCATE) {
12764     // fold (zext (truncate (load x))) -> (zext (smaller load x))
12765     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
12766     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
12767       SDNode *oye = N0.getOperand(0).getNode();
12768       if (NarrowLoad.getNode() != N0.getNode()) {
12769         CombineTo(N0.getNode(), NarrowLoad);
12770         // CombineTo deleted the truncate, if needed, but not what's under it.
12771         AddToWorklist(oye);
12772       }
12773       return SDValue(N, 0); // Return N so it doesn't get rechecked!
12774     }
12775 
12776     EVT SrcVT = N0.getOperand(0).getValueType();
12777     EVT MinVT = N0.getValueType();
12778 
12779     // Try to mask before the extension to avoid having to generate a larger mask,
12780     // possibly over several sub-vectors.
12781     if (SrcVT.bitsLT(VT) && VT.isVector()) {
12782       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
12783                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
12784         SDValue Op = N0.getOperand(0);
12785         Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
12786         AddToWorklist(Op.getNode());
12787         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
12788         // Transfer the debug info; the new node is equivalent to N0.
12789         DAG.transferDbgValues(N0, ZExtOrTrunc);
12790         return ZExtOrTrunc;
12791       }
12792     }
12793 
12794     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
12795       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
12796       AddToWorklist(Op.getNode());
12797       SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
12798       // We may safely transfer the debug info describing the truncate node over
12799       // to the equivalent and operation.
12800       DAG.transferDbgValues(N0, And);
12801       return And;
12802     }
12803   }
12804 
12805   // Fold (zext (and (trunc x), cst)) -> (and x, cst),
12806   // if either of the casts is not free.
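  // e.g. (i64 (zext (and (i32 (trunc i64:x)), 42))) --> (i64 (and x, 42))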
12807   if (N0.getOpcode() == ISD::AND &&
12808       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
12809       N0.getOperand(1).getOpcode() == ISD::Constant &&
12810       (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
12811                            N0.getValueType()) ||
12812        !TLI.isZExtFree(N0.getValueType(), VT))) {
12813     SDValue X = N0.getOperand(0).getOperand(0);
12814     X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
12815     APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
12816     SDLoc DL(N);
12817     return DAG.getNode(ISD::AND, DL, VT,
12818                        X, DAG.getConstant(Mask, DL, VT));
12819   }
12820 
12821   // Try to simplify (zext (load x)).
12822   if (SDValue foldedExt =
12823           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
12824                              ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
12825     return foldedExt;
12826 
12827   if (SDValue foldedExt =
12828       tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
12829                                ISD::ZERO_EXTEND))
12830     return foldedExt;
12831 
12832   // fold (zext (load x)) to multiple smaller zextloads.
12833   // Only on illegal but splittable vectors.
12834   if (SDValue ExtLoad = CombineExtLoad(N))
12835     return ExtLoad;
12836 
12837   // fold (zext (and/or/xor (load x), cst)) ->
12838   //      (and/or/xor (zextload x), (zext cst))
12839   // Unless (and (load x) cst) will match as a zextload already and has
12840   // additional users.
12841   if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
12842        N0.getOpcode() == ISD::XOR) &&
12843       isa<LoadSDNode>(N0.getOperand(0)) &&
12844       N0.getOperand(1).getOpcode() == ISD::Constant &&
12845       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
12846     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
12847     EVT MemVT = LN00->getMemoryVT();
12848     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
12849         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
12850       bool DoXform = true;
12851       SmallVector<SDNode*, 4> SetCCs;
12852       if (!N0.hasOneUse()) {
12853         if (N0.getOpcode() == ISD::AND) {
12854           auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
12855           EVT LoadResultTy = AndC->getValueType(0);
12856           EVT ExtVT;
12857           if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
12858             DoXform = false;
12859         }
12860       }
12861       if (DoXform)
12862         DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
12863                                           ISD::ZERO_EXTEND, SetCCs, TLI);
12864       if (DoXform) {
12865         SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
12866                                          LN00->getChain(), LN00->getBasePtr(),
12867                                          LN00->getMemoryVT(),
12868                                          LN00->getMemOperand());
12869         APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
12870         SDLoc DL(N);
12871         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
12872                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
12873         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
12874         bool NoReplaceTruncAnd = !N0.hasOneUse();
12875         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
12876         CombineTo(N, And);
12877         // If N0 has multiple uses, change other uses as well.
12878         if (NoReplaceTruncAnd) {
12879           SDValue TruncAnd =
12880               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
12881           CombineTo(N0.getNode(), TruncAnd);
12882         }
12883         if (NoReplaceTrunc) {
12884           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
12885         } else {
12886           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
12887                                       LN00->getValueType(0), ExtLoad);
12888           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
12889         }
12890         return SDValue(N,0); // Return N so it doesn't get rechecked!
12891       }
12892     }
12893   }
12894 
12895   // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
12896   //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
12897   if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
12898     return ZExtLoad;
12899 
12900   // Try to simplify (zext (zextload x)).
12901   if (SDValue foldedExt = tryToFoldExtOfExtload(
12902           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
12903     return foldedExt;
12904 
12905   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
12906     return V;
12907 
12908   if (N0.getOpcode() == ISD::SETCC) {
12909     // Propagate fast-math-flags.
12910     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
12911 
12912     // Only do this before legalize for now.
12913     if (!LegalOperations && VT.isVector() &&
12914         N0.getValueType().getVectorElementType() == MVT::i1) {
12915       EVT N00VT = N0.getOperand(0).getValueType();
12916       if (getSetCCResultType(N00VT) == N0.getValueType())
12917         return SDValue();
12918 
12919       // We know that the # elements of the result is the same as the #
12920       // elements of the compare (and the # elements of the compare result for
12921       // that matter). Check to see that they are the same size. If so, we know
12922       // that the element size of the zext'd result matches the element size of
12923       // the compare operands.
12924       SDLoc DL(N);
12925       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
12926         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
12927         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
12928                                      N0.getOperand(1), N0.getOperand(2));
12929         return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
12930       }
12931 
12932       // If the desired elements are smaller or larger than the source
12933       // elements, we can use a matching integer vector type and then
12934       // truncate/any extend followed by zext_in_reg.
12935       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
12936       SDValue VsetCC =
12937           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
12938                       N0.getOperand(1), N0.getOperand(2));
12939       return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
12940                                     N0.getValueType());
12941     }
12942 
12943     // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
12944     SDLoc DL(N);
12945     EVT N0VT = N0.getValueType();
12946     EVT N00VT = N0.getOperand(0).getValueType();
12947     if (SDValue SCC = SimplifySelectCC(
12948             DL, N0.getOperand(0), N0.getOperand(1),
12949             DAG.getBoolConstant(true, DL, N0VT, N00VT),
12950             DAG.getBoolConstant(false, DL, N0VT, N00VT),
12951             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
12952       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
12953   }
12954 
12955   // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
12956   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
12957       isa<ConstantSDNode>(N0.getOperand(1)) &&
12958       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
12959       N0.hasOneUse()) {
12960     SDValue ShAmt = N0.getOperand(1);
12961     if (N0.getOpcode() == ISD::SHL) {
12962       SDValue InnerZExt = N0.getOperand(0);
12963       // If the original shl may be shifting out bits, do not perform this
12964       // transformation.
12965       unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
12966         InnerZExt.getOperand(0).getValueSizeInBits();
12967       if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
12968         return SDValue();
12969     }
12970 
12971     SDLoc DL(N);
12972 
12973     // Ensure that the shift amount is wide enough for the shifted value.
12974     if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
12975       ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
12976 
12977     return DAG.getNode(N0.getOpcode(), DL, VT,
12978                        DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
12979                        ShAmt);
12980   }
12981 
12982   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
12983     return NewVSel;
12984 
12985   if (SDValue NewCtPop = widenCtPop(N, DAG))
12986     return NewCtPop;
12987 
12988   if (SDValue V = widenAbs(N, DAG))
12989     return V;
12990 
12991   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
12992     return Res;
12993 
12994   return SDValue();
12995 }
12996 
12997 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
12998   SDValue N0 = N->getOperand(0);
12999   EVT VT = N->getValueType(0);
13000 
13001   // aext(undef) = undef
13002   if (N0.isUndef())
13003     return DAG.getUNDEF(VT);
13004 
13005   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13006     return Res;
13007 
13008   // fold (aext (aext x)) -> (aext x)
13009   // fold (aext (zext x)) -> (zext x)
13010   // fold (aext (sext x)) -> (sext x)
13011   if (N0.getOpcode() == ISD::ANY_EXTEND  ||
13012       N0.getOpcode() == ISD::ZERO_EXTEND ||
13013       N0.getOpcode() == ISD::SIGN_EXTEND)
13014     return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13015 
13016   // fold (aext (truncate (load x))) -> (aext (smaller load x))
13017   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
13018   if (N0.getOpcode() == ISD::TRUNCATE) {
13019     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13020       SDNode *oye = N0.getOperand(0).getNode();
13021       if (NarrowLoad.getNode() != N0.getNode()) {
13022         CombineTo(N0.getNode(), NarrowLoad);
13023         // CombineTo deleted the truncate, if needed, but not what's under it.
13024         AddToWorklist(oye);
13025       }
13026       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13027     }
13028   }
13029 
13030   // fold (aext (truncate x))
13031   if (N0.getOpcode() == ISD::TRUNCATE)
13032     return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
13033 
13034   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
13035   // if the trunc is not free.
13036   if (N0.getOpcode() == ISD::AND &&
13037       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
13038       N0.getOperand(1).getOpcode() == ISD::Constant &&
13039       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
13040                           N0.getValueType())) {
13041     SDLoc DL(N);
13042     SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
13043     SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
13044     assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
13045     return DAG.getNode(ISD::AND, DL, VT, X, Y);
13046   }
13047 
13048   // fold (aext (load x)) -> (aext (truncate (extload x)))
13049   // None of the supported targets knows how to perform load and any_ext
13050   // on vectors in one instruction, so attempt to fold to zext instead.
13051   if (VT.isVector()) {
13052     // Try to simplify (zext (load x)).
13053     if (SDValue foldedExt =
13054             tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13055                                ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13056       return foldedExt;
13057   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
13058              ISD::isUNINDEXEDLoad(N0.getNode()) &&
13059              TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
13060     bool DoXform = true;
13061     SmallVector<SDNode *, 4> SetCCs;
13062     if (!N0.hasOneUse())
13063       DoXform =
13064           ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
13065     if (DoXform) {
13066       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13067       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
13068                                        LN0->getChain(), LN0->getBasePtr(),
13069                                        N0.getValueType(), LN0->getMemOperand());
13070       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
13071       // If the load value is used only by N, replace it via CombineTo N.
13072       bool NoReplaceTrunc = N0.hasOneUse();
13073       CombineTo(N, ExtLoad);
13074       if (NoReplaceTrunc) {
13075         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13076         recursivelyDeleteUnusedNodes(LN0);
13077       } else {
13078         SDValue Trunc =
13079             DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
13080         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
13081       }
13082       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13083     }
13084   }
13085 
13086   // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
13087   // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
13088   // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
13089   if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
13090       ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
13091     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13092     ISD::LoadExtType ExtType = LN0->getExtensionType();
13093     EVT MemVT = LN0->getMemoryVT();
13094     if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
13095       SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
13096                                        VT, LN0->getChain(), LN0->getBasePtr(),
13097                                        MemVT, LN0->getMemOperand());
13098       CombineTo(N, ExtLoad);
13099       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13100       recursivelyDeleteUnusedNodes(LN0);
13101       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13102     }
13103   }
13104 
13105   if (N0.getOpcode() == ISD::SETCC) {
13106     // Propagate fast-math-flags.
13107     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13108 
13109     // For vectors:
13110     // aext(setcc) -> vsetcc
13111     // aext(setcc) -> truncate(vsetcc)
13112     // aext(setcc) -> aext(vsetcc)
13113     // Only do this before legalize for now.
13114     if (VT.isVector() && !LegalOperations) {
13115       EVT N00VT = N0.getOperand(0).getValueType();
13116       if (getSetCCResultType(N00VT) == N0.getValueType())
13117         return SDValue();
13118 
13119       // We know that the # elements of the result is the same as the
13120       // # elements of the compare (and the # elements of the compare result
13121       // for that matter).  Check to see that they are the same size.  If so,
13122       // we know that the element size of the extended result matches the
13123       // element size of the compare operands.
13124       if (VT.getSizeInBits() == N00VT.getSizeInBits())
13125         return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
13126                              N0.getOperand(1),
13127                              cast<CondCodeSDNode>(N0.getOperand(2))->get());
13128 
13129       // If the desired elements are smaller or larger than the source
13130       // elements, we can use a matching integer vector type and then
13131       // truncate/any extend
13132       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
13133       SDValue VsetCC =
13134         DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
13135                       N0.getOperand(1),
13136                       cast<CondCodeSDNode>(N0.getOperand(2))->get());
13137       return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
13138     }
13139 
13140     // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
13141     SDLoc DL(N);
13142     if (SDValue SCC = SimplifySelectCC(
13143             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
13144             DAG.getConstant(0, DL, VT),
13145             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
13146       return SCC;
13147   }
13148 
13149   if (SDValue NewCtPop = widenCtPop(N, DAG))
13150     return NewCtPop;
13151 
13152   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
13153     return Res;
13154 
13155   return SDValue();
13156 }
13157 
13158 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
13159   unsigned Opcode = N->getOpcode();
13160   SDValue N0 = N->getOperand(0);
13161   SDValue N1 = N->getOperand(1);
13162   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
13163 
13164   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
13165   if (N0.getOpcode() == Opcode &&
13166       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
13167     return N0;
13168 
13169   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
13170       N0.getOperand(0).getOpcode() == Opcode) {
13171     // We have an assert, truncate, assert sandwich. Make one stronger assert
13172     // by applying the smallest asserted type to the larger source value.
13173     // This eliminates the later assert:
13174     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
13175     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
13176     SDLoc DL(N);
13177     SDValue BigA = N0.getOperand(0);
13178     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
13179     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
13180     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
13181     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
13182                                     BigA.getOperand(0), MinAssertVTVal);
13183     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
13184   }
13185 
13186   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
13187   // than X, just move the AssertZext in front of the truncate and drop the
13188   // AssertSExt.
13189   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
13190       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
13191       Opcode == ISD::AssertZext) {
13192     SDValue BigA = N0.getOperand(0);
13193     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
13194     if (AssertVT.bitsLT(BigA_AssertVT)) {
13195       SDLoc DL(N);
13196       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
13197                                       BigA.getOperand(0), N1);
13198       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
13199     }
13200   }
13201 
13202   return SDValue();
13203 }
13204 
13205 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
13206   SDLoc DL(N);
13207 
13208   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
13209   SDValue N0 = N->getOperand(0);
13210 
13211   // Fold (assertalign (assertalign x, AL0), AL1) ->
13212   // (assertalign x, max(AL0, AL1))
13213   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
13214     return DAG.getAssertAlign(DL, N0.getOperand(0),
13215                               std::max(AL, AAN->getAlign()));
13216 
13217   // In rare cases, there are trivial arithmetic ops in the source operands.
13218   // Sink this assert down to the source operands so that those arithmetic ops
13219   // can be exposed to DAG combining.
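  // e.g. given an alignment of 8, if x's alignment is not already known:
  //   (assertalign (add x, 8), 8) --> (add (assertalign x, 8), 8)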
13220   switch (N0.getOpcode()) {
13221   default:
13222     break;
13223   case ISD::ADD:
13224   case ISD::SUB: {
13225     unsigned AlignShift = Log2(AL);
13226     SDValue LHS = N0.getOperand(0);
13227     SDValue RHS = N0.getOperand(1);
13228     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
13229     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
13230     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
13231       if (LHSAlignShift < AlignShift)
13232         LHS = DAG.getAssertAlign(DL, LHS, AL);
13233       if (RHSAlignShift < AlignShift)
13234         RHS = DAG.getAssertAlign(DL, RHS, AL);
13235       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
13236     }
13237     break;
13238   }
13239   }
13240 
13241   return SDValue();
13242 }
13243 
13244 /// If the result of a load is shifted/masked/truncated to an effectively
13245 /// narrower type, try to transform the load to a narrower type and/or
13246 /// use an extending load.
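/// e.g. (i32 (and (load i32:p), 255)) --> (i32 (zextload i8 from p))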
13247 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
13248   unsigned Opc = N->getOpcode();
13249 
13250   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
13251   SDValue N0 = N->getOperand(0);
13252   EVT VT = N->getValueType(0);
13253   EVT ExtVT = VT;
13254 
13255   // This transformation isn't valid for vector loads.
13256   if (VT.isVector())
13257     return SDValue();
13258 
13259   // The ShAmt variable is used to indicate that we've consumed a right
13260   // shift, i.e. we want to narrow the width of the load by skipping the ShAmt
13261   // least significant bits.
13262   unsigned ShAmt = 0;
13263   // A special case is when the least significant bits from the load are masked
13264   // away, but using an AND rather than a right shift. HasShiftedOffset is used
13265   // to indicate that the narrowed load should be left-shifted ShAmt bits to get
13266   // the result.
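  // e.g. (and (load i32:p), 0xFF00) may become a one-byte load from p+1 on a
  // little-endian target, followed by (shl result, 8).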
13267   bool HasShiftedOffset = false;
13268   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
13269   // extended to VT.
13270   if (Opc == ISD::SIGN_EXTEND_INREG) {
13271     ExtType = ISD::SEXTLOAD;
13272     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
13273   } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
13274     // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
13275     // value, or it may be shifting a higher subword, half or byte into the
13276     // lowest bits.
13277 
13278     // Only handle shift with constant shift amount, and the shiftee must be a
13279     // load.
13280     auto *LN = dyn_cast<LoadSDNode>(N0);
13281     auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
13282     if (!N1C || !LN)
13283       return SDValue();
13284     // If the shift amount is larger than the memory type then we're not
13285     // accessing any of the loaded bytes.
13286     ShAmt = N1C->getZExtValue();
13287     uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
13288     if (MemoryWidth <= ShAmt)
13289       return SDValue();
13290     // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
13291     ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
13292     ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
13293     // If original load is a SEXTLOAD then we can't simply replace it by a
13294     // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
13295     // followed by a ZEXT, but that is not handled at the moment). Similarly if
13296     // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
13297     if ((LN->getExtensionType() == ISD::SEXTLOAD ||
13298          LN->getExtensionType() == ISD::ZEXTLOAD) &&
13299         LN->getExtensionType() != ExtType)
13300       return SDValue();
13301   } else if (Opc == ISD::AND) {
13302     // An AND with a constant mask is the same as a truncate + zero-extend.
13303     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
13304     if (!AndC)
13305       return SDValue();
13306 
13307     const APInt &Mask = AndC->getAPIntValue();
13308     unsigned ActiveBits = 0;
13309     if (Mask.isMask()) {
13310       ActiveBits = Mask.countTrailingOnes();
13311     } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
13312       HasShiftedOffset = true;
13313     } else {
13314       return SDValue();
13315     }
13316 
13317     ExtType = ISD::ZEXTLOAD;
13318     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
13319   }
13320 
13321   // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
13322   // a right shift. Here we redo some of those checks, to possibly adjust the
13323   // ExtVT even further based on "a masking AND". We could also end up here for
13324   // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
13325   // need to be done here as well.
13326   if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
13327     SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
13328     // Bail out when the SRL has more than one use. This is done for historical
13329     // (undocumented) reasons. Maybe the intent was to guard the AND-masking
13330     // check below? And maybe it could be non-profitable to do the transform
13331     // when the SRL has multiple uses and we get here with Opc!=ISD::SRL?
13332     // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
13333     if (!SRL.hasOneUse())
13334       return SDValue();
13335 
13336     // Only handle shift with constant shift amount, and the shiftee must be a
13337     // load.
13338     auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
13339     auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
13340     if (!SRL1C || !LN)
13341       return SDValue();
13342 
13343     // If the shift amount is larger than the input type then we're not
13344     // accessing any of the loaded bytes.  If the load was a zextload/extload
13345     // then the result of the shift+trunc is zero/undef (handled elsewhere).
13346     ShAmt = SRL1C->getZExtValue();
13347     uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
13348     if (ShAmt >= MemoryWidth)
13349       return SDValue();
13350 
13351     // Because a SRL must be assumed to *need* to zero-extend the high bits
13352     // (as opposed to anyext the high bits), we can't combine the zextload
13353     // lowering of SRL and an sextload.
13354     if (LN->getExtensionType() == ISD::SEXTLOAD)
13355       return SDValue();
13356 
13357     // Avoid reading outside the memory accessed by the original load (could
13358     // happen if we only adjust the load base pointer by ShAmt). Instead we
13359     // try to narrow the load even further. The typical scenario here is:
13360     //   (i64 (truncate (i96 (srl (load x), 64)))) ->
13361     //     (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
13362     if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
13363       // Don't replace sextload by zextload.
13364       if (ExtType == ISD::SEXTLOAD)
13365         return SDValue();
13366       // Narrow the load.
13367       ExtType = ISD::ZEXTLOAD;
13368       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
13369     }
13370 
13371     // If the SRL is only used by a masking AND, we may be able to adjust
13372     // the ExtVT to make the AND redundant.
13373     SDNode *Mask = *(SRL->use_begin());
13374     if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
13375         isa<ConstantSDNode>(Mask->getOperand(1))) {
13376       const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
13377       if (ShiftMask.isMask()) {
13378         EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
13379                                          ShiftMask.countTrailingOnes());
13380         // If the mask is smaller, recompute the type.
13381         if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
13382             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
13383           ExtVT = MaskedVT;
13384       }
13385     }
13386 
13387     N0 = SRL.getOperand(0);
13388   }
13389 
13390   // If the load is shifted left (and the result isn't shifted back right), we
13391   // can fold a truncate through the shift. The typical scenario is that N
13392   // points at a TRUNCATE here so the attempted fold is:
13393   //   (truncate (shl (load x), c))) -> (shl (narrow load x), c)
13394   // ShLeftAmt will indicate how much a narrowed load should be shifted left.
13395   unsigned ShLeftAmt = 0;
13396   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
13397       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
13398     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
13399       ShLeftAmt = N01->getZExtValue();
13400       N0 = N0.getOperand(0);
13401     }
13402   }
13403 
13404   // If we haven't found a load, we can't narrow it.
13405   if (!isa<LoadSDNode>(N0))
13406     return SDValue();
13407 
13408   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13409   // Reducing the width of a volatile load is illegal.  For atomics, we may be
13410   // able to reduce the width provided we never widen again. (see D66309)
13411   if (!LN0->isSimple() ||
13412       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
13413     return SDValue();
13414 
13415   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
13416     unsigned LVTStoreBits =
13417         LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
13418     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
13419     return LVTStoreBits - EVTStoreBits - ShAmt;
13420   };
13421 
13422   // We need to adjust the pointer to the load by ShAmt bits in order to load
13423   // the correct bytes.
13424   unsigned PtrAdjustmentInBits =
13425       DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
13426 
13427   uint64_t PtrOff = PtrAdjustmentInBits / 8;
13428   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
13429   SDLoc DL(LN0);
13430   // The original load itself didn't wrap, so an offset within it doesn't.
13431   SDNodeFlags Flags;
13432   Flags.setNoUnsignedWrap(true);
13433   SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
13434                                             TypeSize::Fixed(PtrOff), DL, Flags);
13435   AddToWorklist(NewPtr.getNode());
13436 
13437   SDValue Load;
13438   if (ExtType == ISD::NON_EXTLOAD)
13439     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
13440                        LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
13441                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
13442   else
13443     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
13444                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
13445                           NewAlign, LN0->getMemOperand()->getFlags(),
13446                           LN0->getAAInfo());
13447 
13448   // Replace the old load's chain with the new load's chain.
13449   WorklistRemover DeadNodes(*this);
13450   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
13451 
13452   // Shift the result left, if we've swallowed a left shift.
13453   SDValue Result = Load;
13454   if (ShLeftAmt != 0) {
13455     EVT ShImmTy = getShiftAmountTy(Result.getValueType());
13456     if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
13457       ShImmTy = VT;
13458     // If the shift amount is as large as the result size (but, presumably,
13459     // no larger than the source) then the useful bits of the result are
13460     // zero; we can't simply return the shortened shift, because the result
13461     // of that operation is undefined.
13462     if (ShLeftAmt >= VT.getScalarSizeInBits())
13463       Result = DAG.getConstant(0, DL, VT);
13464     else
13465       Result = DAG.getNode(ISD::SHL, DL, VT,
13466                           Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
13467   }
13468 
13469   if (HasShiftedOffset) {
13470     // We're using a shifted mask, so the load now has an offset. This means
13471     // the data has been loaded into lower bytes of the register than it would
13472     // have been before, so we need to shl the loaded data into the correct
13473     // position in the register.
13474     SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
13475     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
13476     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
13477   }
13478 
13479   // Return the new loaded value.
13480   return Result;
13481 }
13482 
13483 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
13484   SDValue N0 = N->getOperand(0);
13485   SDValue N1 = N->getOperand(1);
13486   EVT VT = N->getValueType(0);
13487   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
13488   unsigned VTBits = VT.getScalarSizeInBits();
13489   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
13490 
13491   // sext_in_reg(undef) = 0 because the top bits will all be the same.
13492   if (N0.isUndef())
13493     return DAG.getConstant(0, SDLoc(N), VT);
13494 
13495   // fold (sext_in_reg c1) -> c1
13496   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
13497     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
13498 
13499   // If the input is already sign extended, just drop the extension.
13500   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
13501     return N0;
13502 
13503   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
13504   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
13505       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
13506     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
13507                        N1);
13508 
13509   // fold (sext_in_reg (sext x)) -> (sext x)
13510   // fold (sext_in_reg (aext x)) -> (sext x)
13511   // if x is small enough or if we know that x has more than 1 sign bit and the
13512   // sign_extend_inreg is extending from one of them.
13513   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13514     SDValue N00 = N0.getOperand(0);
13515     unsigned N00Bits = N00.getScalarValueSizeInBits();
13516     if ((N00Bits <= ExtVTBits ||
13517          DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
13518         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13519       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
13520   }
13521 
13522   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
13523   // if x is small enough or if we know that x has more than 1 sign bit and the
13524   // sign_extend_inreg is extending from one of them.
13525   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13526       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
13527       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
13528     SDValue N00 = N0.getOperand(0);
13529     unsigned N00Bits = N00.getScalarValueSizeInBits();
13530     unsigned DstElts = N0.getValueType().getVectorMinNumElements();
13531     unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
13532     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
13533     APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
13534     if ((N00Bits == ExtVTBits ||
13535          (!IsZext && (N00Bits < ExtVTBits ||
13536                       DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
13537         (!LegalOperations ||
13538          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
13539       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
13540   }
13541 
13542   // fold (sext_in_reg (zext x)) -> (sext x)
13543   // iff we are extending the source sign bit.
13544   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
13545     SDValue N00 = N0.getOperand(0);
13546     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
13547         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13548       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
13549   }
13550 
13551   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
13552   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
13553     return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
13554 
13555   // fold operands of sext_in_reg based on knowledge that the top bits are not
13556   // demanded.
13557   if (SimplifyDemandedBits(SDValue(N, 0)))
13558     return SDValue(N, 0);
13559 
13560   // fold (sext_in_reg (load x)) -> (smaller sextload x)
13561   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
13562   if (SDValue NarrowLoad = reduceLoadWidth(N))
13563     return NarrowLoad;
13564 
13565   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
13566   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
13567   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
13568   if (N0.getOpcode() == ISD::SRL) {
13569     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
13570       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
13571         // We can turn this into an SRA iff the input to the SRL is already sign
13572         // extended enough.
13573         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
13574         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
13575           return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
13576                              N0.getOperand(1));
13577       }
13578   }
13579 
13580   // fold (sext_inreg (extload x)) -> (sextload x)
13581   // If sextload is not supported by target, we can only do the combine when
13582   // load has one use. Doing otherwise can block folding the extload with other
13583   // extends that the target does support.
13584   if (ISD::isEXTLoad(N0.getNode()) &&
13585       ISD::isUNINDEXEDLoad(N0.getNode()) &&
13586       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13587       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
13588         N0.hasOneUse()) ||
13589        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13590     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13591     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13592                                      LN0->getChain(),
13593                                      LN0->getBasePtr(), ExtVT,
13594                                      LN0->getMemOperand());
13595     CombineTo(N, ExtLoad);
13596     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13597     AddToWorklist(ExtLoad.getNode());
13598     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13599   }
13600 
13601   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
13602   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
13603       N0.hasOneUse() &&
13604       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13605       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
13606        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13607     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13608     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13609                                      LN0->getChain(),
13610                                      LN0->getBasePtr(), ExtVT,
13611                                      LN0->getMemOperand());
13612     CombineTo(N, ExtLoad);
13613     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13614     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13615   }
13616 
13617   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
13618   // ignore it if the masked load is already sign extended
13619   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
13620     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
13621         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
13622         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
13623       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
13624           VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
13625           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
13626           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
13627       CombineTo(N, ExtMaskedLoad);
13628       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
13629       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13630     }
13631   }
13632 
13633   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
13634   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
13635     if (SDValue(GN0, 0).hasOneUse() &&
13636         ExtVT == GN0->getMemoryVT() &&
13637         TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
13638       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
13639                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
13640 
13641       SDValue ExtLoad = DAG.getMaskedGather(
13642           DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
13643           GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
13644 
13645       CombineTo(N, ExtLoad);
13646       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13647       AddToWorklist(ExtLoad.getNode());
13648       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13649     }
13650   }
13651 
13652   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
13653   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
13654     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
13655                                            N0.getOperand(1), false))
13656       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
13657   }
13658 
13659   // Fold (iM_signext_inreg
13660   //        (extract_subvector (zext|anyext|sext iN_v to _) _)
13661   //        from iN)
13662   //      -> (extract_subvector (signext iN_v to iM))
13663   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
13664       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
13665     SDValue InnerExt = N0.getOperand(0);
13666     EVT InnerExtVT = InnerExt->getValueType(0);
13667     SDValue Extendee = InnerExt->getOperand(0);
13668 
13669     if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
13670         (!LegalOperations ||
13671          TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
13672       SDValue SignExtExtendee =
13673           DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), InnerExtVT, Extendee);
13674       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, SignExtExtendee,
13675                          N0.getOperand(1));
13676     }
13677   }
13678 
13679   return SDValue();
13680 }
13681 
13682 static SDValue
13683 foldExtendVectorInregToExtendOfSubvector(SDNode *N, const TargetLowering &TLI,
13684                                          SelectionDAG &DAG,
13685                                          bool LegalOperations) {
13686   unsigned InregOpcode = N->getOpcode();
13687   unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
13688 
13689   SDValue Src = N->getOperand(0);
13690   EVT VT = N->getValueType(0);
13691   EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
13692                                Src.getValueType().getVectorElementType(),
13693                                VT.getVectorElementCount());
13694 
13695   assert((InregOpcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
13696           InregOpcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
13697           InregOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
13698          "Expected EXTEND_VECTOR_INREG dag node in input!");
13699 
13700   // Profitability check: our operand must be a one-use CONCAT_VECTORS.
13701   // FIXME: one-use check may be overly restrictive
13702   if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
13703     return SDValue();
13704 
13705   // Profitability check: we must be extending exactly one of its operands.
13706   // FIXME: this is probably overly restrictive.
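  // e.g. (v4i32 (sign_extend_vector_inreg (concat_vectors v4i16:x, y)))
  //      --> (v4i32 (sign_extend x))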
13707   Src = Src.getOperand(0);
13708   if (Src.getValueType() != SrcVT)
13709     return SDValue();
13710 
13711   if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
13712     return SDValue();
13713 
13714   return DAG.getNode(Opcode, SDLoc(N), VT, Src);
13715 }
13716 
13717 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
13718   SDValue N0 = N->getOperand(0);
13719   EVT VT = N->getValueType(0);
13720 
13721   if (N0.isUndef()) {
13722     // aext_vector_inreg(undef) = undef because the top bits are undefined.
13723     // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
13724     return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
13725                ? DAG.getUNDEF(VT)
13726                : DAG.getConstant(0, SDLoc(N), VT);
13727   }
13728 
13729   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13730     return Res;
13731 
13732   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
13733     return SDValue(N, 0);
13734 
13735   if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, TLI, DAG,
13736                                                            LegalOperations))
13737     return R;
13738 
13739   return SDValue();
13740 }
13741 
13742 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
13743   SDValue N0 = N->getOperand(0);
13744   EVT VT = N->getValueType(0);
13745   EVT SrcVT = N0.getValueType();
13746   bool isLE = DAG.getDataLayout().isLittleEndian();
13747 
13748   // noop truncate
13749   if (SrcVT == VT)
13750     return N0;
13751 
13752   // fold (truncate (truncate x)) -> (truncate x)
13753   if (N0.getOpcode() == ISD::TRUNCATE)
13754     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
13755 
13756   // fold (truncate c1) -> c1
13757   if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
13758     SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
13759     if (C.getNode() != N)
13760       return C;
13761   }
13762 
13763   // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
13764   if (N0.getOpcode() == ISD::ZERO_EXTEND ||
13765       N0.getOpcode() == ISD::SIGN_EXTEND ||
13766       N0.getOpcode() == ISD::ANY_EXTEND) {
13767     // if the source is smaller than the dest, we still need an extend.
13768     if (N0.getOperand(0).getValueType().bitsLT(VT))
13769       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13770     // if the source is larger than the dest, then we just need the truncate.
13771     if (N0.getOperand(0).getValueType().bitsGT(VT))
13772       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
13773     // if the source and dest are the same type, we can drop both the extend
13774     // and the truncate.
13775     return N0.getOperand(0);
13776   }
13777 
13778   // Try to narrow a truncate-of-sext_in_reg to the destination type:
13779   // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
13780   if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
13781       N0.hasOneUse()) {
13782     SDValue X = N0.getOperand(0);
13783     SDValue ExtVal = N0.getOperand(1);
13784     EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
13785     if (ExtVT.bitsLT(VT)) {
13786       SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
13787       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
13788     }
13789   }
13790 
13791   // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
13792   if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
13793     return SDValue();
13794 
13795   // Fold extract-and-trunc into a narrow extract. For example:
13796   //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
13797   //   i32 y = TRUNCATE(i64 x)
13798   //        -- becomes --
13799   //   v16i8 b = BITCAST (v2i64 val)
13800   //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
13801   //
13802   // Note: We only run this optimization after type legalization (which often
13803   // creates this pattern) and before operation legalization, after which
13804   // we need to be more careful about the vector instructions that we generate.
13805   if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
13806       LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
13807     EVT VecTy = N0.getOperand(0).getValueType();
13808     EVT ExTy = N0.getValueType();
13809     EVT TrTy = N->getValueType(0);
13810 
13811     auto EltCnt = VecTy.getVectorElementCount();
13812     unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
13813     auto NewEltCnt = EltCnt * SizeRatio;
13814 
13815     EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
13816     assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
13817 
13818     SDValue EltNo = N0->getOperand(1);
13819     if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
13820       int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
13821       int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
13822 
13823       SDLoc DL(N);
13824       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
13825                          DAG.getBitcast(NVT, N0.getOperand(0)),
13826                          DAG.getVectorIdxConstant(Index, DL));
13827     }
13828   }
13829 
13830   // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
13831   if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
13832     if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
13833         TLI.isTruncateFree(SrcVT, VT)) {
13834       SDLoc SL(N0);
13835       SDValue Cond = N0.getOperand(0);
13836       SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
13837       SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
13838       return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
13839     }
13840   }
13841 
13842   // trunc (shl x, K) -> shl (trunc x), K, if K < VT.getScalarSizeInBits()
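  // e.g. (i16 (trunc (shl i32:x, 4))) --> (i16 (shl (i16 (trunc x)), 4))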
13843   if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
13844       (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
13845       TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
13846     SDValue Amt = N0.getOperand(1);
13847     KnownBits Known = DAG.computeKnownBits(Amt);
13848     unsigned Size = VT.getScalarSizeInBits();
13849     if (Known.countMaxActiveBits() <= Log2_32(Size)) {
13850       SDLoc SL(N);
13851       EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
13852 
13853       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
13854       if (AmtVT != Amt.getValueType()) {
13855         Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
13856         AddToWorklist(Amt.getNode());
13857       }
13858       return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
13859     }
13860   }
13861 
13862   if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
13863     return V;
13864 
13865   // Attempt to pre-truncate BUILD_VECTOR sources.
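  // e.g. (v2i16 (trunc (v2i32 (build_vector i32:x, i32:y))))
  //      --> (v2i16 (build_vector (i16 (trunc x)), (i16 (trunc y))))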
13866   if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
13867       TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
13868       // Avoid creating illegal types if running after type legalizer.
13869       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
13870     SDLoc DL(N);
13871     EVT SVT = VT.getScalarType();
13872     SmallVector<SDValue, 8> TruncOps;
13873     for (const SDValue &Op : N0->op_values()) {
13874       SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
13875       TruncOps.push_back(TruncOp);
13876     }
13877     return DAG.getBuildVector(VT, DL, TruncOps);
13878   }
13879 
13880   // Fold a series of buildvector, bitcast, and truncate if possible.
13881   // For example fold
13882   //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
13883   //   (2xi32 (buildvector x, y)).
13884   if (Level == AfterLegalizeVectorOps && VT.isVector() &&
13885       N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
13886       N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
13887       N0.getOperand(0).hasOneUse()) {
13888     SDValue BuildVect = N0.getOperand(0);
13889     EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
13890     EVT TruncVecEltTy = VT.getVectorElementType();
13891 
13892     // Check that the element types match.
13893     if (BuildVectEltTy == TruncVecEltTy) {
13894       // Now we only need to compute the offset of the truncated elements.
13895       unsigned BuildVecNumElts =  BuildVect.getNumOperands();
13896       unsigned TruncVecNumElts = VT.getVectorNumElements();
13897       unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
13898 
13899       assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
13900              "Invalid number of elements");
13901 
13902       SmallVector<SDValue, 8> Opnds;
13903       for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
13904         Opnds.push_back(BuildVect.getOperand(i));
13905 
13906       return DAG.getBuildVector(VT, SDLoc(N), Opnds);
13907     }
13908   }
13909 
13910   // fold (truncate (load x)) -> (smaller load x)
13911   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
13912   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
13913     if (SDValue Reduced = reduceLoadWidth(N))
13914       return Reduced;
13915 
13916     // Handle the case where the load remains an extending load even
13917     // after truncation.
13918     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
13919       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13920       if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
13921         SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
13922                                          VT, LN0->getChain(), LN0->getBasePtr(),
13923                                          LN0->getMemoryVT(),
13924                                          LN0->getMemOperand());
13925         DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
13926         return NewLoad;
13927       }
13928     }
13929   }
13930 
13931   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
13932   // where ... are all 'undef'.
13933   if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
13934     SmallVector<EVT, 8> VTs;
13935     SDValue V;
13936     unsigned Idx = 0;
13937     unsigned NumDefs = 0;
13938 
13939     for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
13940       SDValue X = N0.getOperand(i);
13941       if (!X.isUndef()) {
13942         V = X;
13943         Idx = i;
13944         NumDefs++;
13945       }
13946       // Stop if more than one member is non-undef.
13947       if (NumDefs > 1)
13948         break;
13949 
13950       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
13951                                      VT.getVectorElementType(),
13952                                      X.getValueType().getVectorElementCount()));
13953     }
13954 
13955     if (NumDefs == 0)
13956       return DAG.getUNDEF(VT);
13957 
13958     if (NumDefs == 1) {
13959       assert(V.getNode() && "The single defined operand is empty!");
13960       SmallVector<SDValue, 8> Opnds;
13961       for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
13962         if (i != Idx) {
13963           Opnds.push_back(DAG.getUNDEF(VTs[i]));
13964           continue;
13965         }
13966         SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
13967         AddToWorklist(NV.getNode());
13968         Opnds.push_back(NV);
13969       }
13970       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
13971     }
13972   }
13973 
13974   // Fold truncate of a bitcast of a vector to an extract of the low vector
13975   // element.
13976   //
13977   // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
13978   if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
13979     SDValue VecSrc = N0.getOperand(0);
13980     EVT VecSrcVT = VecSrc.getValueType();
13981     if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
13982         (!LegalOperations ||
13983          TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
13984       SDLoc SL(N);
13985 
13986       unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
13987       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
13988                          DAG.getVectorIdxConstant(Idx, SL));
13989     }
13990   }
13991 
13992   // Simplify the operands using demanded-bits information.
13993   if (SimplifyDemandedBits(SDValue(N, 0)))
13994     return SDValue(N, 0);
13995 
13996   // fold (truncate (extract_subvector(ext x))) ->
13997   //      (extract_subvector x)
13998   // TODO: This can be generalized to cover cases where the truncate and extract
13999   // do not fully cancel each other out.
14000   if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
14001     SDValue N00 = N0.getOperand(0);
14002     if (N00.getOpcode() == ISD::SIGN_EXTEND ||
14003         N00.getOpcode() == ISD::ZERO_EXTEND ||
14004         N00.getOpcode() == ISD::ANY_EXTEND) {
14005       if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
14006           VT.getVectorElementType())
14007         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
14008                            N00.getOperand(0), N0.getOperand(1));
14009     }
14010   }
14011 
14012   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14013     return NewVSel;
14014 
14015   // Narrow a suitable binary operation with a non-opaque constant operand by
14016   // moving it ahead of the truncate. This is limited to pre-legalization
14017   // because targets may prefer a wider type during later combines and invert
14018   // this transform.
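        // For instance, before legalization this would rewrite
        //   (i16 (trunc (add i32:x, i32:42))) -> (add (i16 (trunc x)), i16:42)
        // so the arithmetic happens in the narrower type.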
14019   switch (N0.getOpcode()) {
14020   case ISD::ADD:
14021   case ISD::SUB:
14022   case ISD::MUL:
14023   case ISD::AND:
14024   case ISD::OR:
14025   case ISD::XOR:
14026     if (!LegalOperations && N0.hasOneUse() &&
14027         (isConstantOrConstantVector(N0.getOperand(0), true) ||
14028          isConstantOrConstantVector(N0.getOperand(1), true))) {
14029       // TODO: We already restricted this to pre-legalization, but for vectors
14030       // we are extra cautious to not create an unsupported operation.
14031       // Target-specific changes are likely needed to avoid regressions here.
14032       if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
14033         SDLoc DL(N);
14034         SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14035         SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14036         return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
14037       }
14038     }
14039     break;
14040   case ISD::ADDE:
14041   case ISD::ADDCARRY:
14042     // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
14043     // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
14044     // When the adde's carry is not used.
14045     // We only do this for ADDCARRY before operation legalization.
14046     if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
14047          TLI.isOperationLegal(N0.getOpcode(), VT)) &&
14048         N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
14049       SDLoc DL(N);
14050       SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14051       SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14052       SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
14053       return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
14054     }
14055     break;
14056   case ISD::USUBSAT:
14057     // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
14058     // enough to know that the upper bits are zero, we must also ensure that
14059     // we don't introduce an extra truncate.
14060     if (!LegalOperations && N0.hasOneUse() &&
14061         N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
14062         N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
14063             VT.getScalarSizeInBits() &&
14064         hasOperation(N0.getOpcode(), VT)) {
14065       return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
14066                                  DAG, SDLoc(N));
14067     }
14068     break;
14069   }
14070 
14071   return SDValue();
14072 }
14073 
14074 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
14075   SDValue Elt = N->getOperand(i);
14076   if (Elt.getOpcode() != ISD::MERGE_VALUES)
14077     return Elt.getNode();
14078   return Elt.getOperand(Elt.getResNo()).getNode();
14079 }
14080 
14081 /// build_pair (load, load) -> load
14082 /// if load locations are consecutive.
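      /// e.g. on a little-endian target,
      ///   (i64 build_pair (i32 load [p]), (i32 load [p+4])) -> (i64 load [p])
      /// since elt 0 of the BUILD_PAIR holds the least significant half.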
14083 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
14084   assert(N->getOpcode() == ISD::BUILD_PAIR);
14085 
14086   auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
14087   auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
14088 
14089   // A BUILD_PAIR always has the least significant part in elt 0 and the most
14090   // significant part in elt 1, so when combining into one large load, we need
14091   // to consider the endianness.
14092   if (DAG.getDataLayout().isBigEndian())
14093     std::swap(LD1, LD2);
14094 
14095   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
14096       !LD1->hasOneUse() || !LD2->hasOneUse() ||
14097       LD1->getAddressSpace() != LD2->getAddressSpace())
14098     return SDValue();
14099 
14100   unsigned LD1Fast = 0;
14101   EVT LD1VT = LD1->getValueType(0);
14102   unsigned LD1Bytes = LD1VT.getStoreSize();
14103   if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
14104       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
14105       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
14106                              *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
14107     return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
14108                        LD1->getPointerInfo(), LD1->getAlign());
14109 
14110   return SDValue();
14111 }
14112 
14113 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
14114   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
14115   // and Lo parts; on big-endian machines it doesn't.
14116   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
14117 }
14118 
14119 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
14120                                     const TargetLowering &TLI) {
14121   // If this is not a bitcast to an FP type or if the target doesn't have
14122   // IEEE754-compliant FP logic, we're done.
14123   EVT VT = N->getValueType(0);
14124   if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
14125     return SDValue();
14126 
14127   // TODO: Handle cases where the integer constant is a different scalar
14128   // bitwidth to the FP.
14129   SDValue N0 = N->getOperand(0);
14130   EVT SourceVT = N0.getValueType();
14131   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
14132     return SDValue();
14133 
14134   unsigned FPOpcode;
14135   APInt SignMask;
14136   switch (N0.getOpcode()) {
14137   case ISD::AND:
14138     FPOpcode = ISD::FABS;
14139     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
14140     break;
14141   case ISD::XOR:
14142     FPOpcode = ISD::FNEG;
14143     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
14144     break;
14145   case ISD::OR:
14146     FPOpcode = ISD::FABS;
14147     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
14148     break;
14149   default:
14150     return SDValue();
14151   }
14152 
14153   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
14154   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
14155   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
14156   //   fneg (fabs X)
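        // A concrete f32 instance of the first pattern:
        //   (f32 bitcast (and (i32 bitcast f32:X), 0x7fffffff)) -> (fabs X)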
14157   SDValue LogicOp0 = N0.getOperand(0);
14158   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
14159   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
14160       LogicOp0.getOpcode() == ISD::BITCAST &&
14161       LogicOp0.getOperand(0).getValueType() == VT) {
14162     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
14163     NumFPLogicOpsConv++;
14164     if (N0.getOpcode() == ISD::OR)
14165       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
14166     return FPOp;
14167   }
14168 
14169   return SDValue();
14170 }
14171 
14172 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
14173   SDValue N0 = N->getOperand(0);
14174   EVT VT = N->getValueType(0);
14175 
14176   if (N0.isUndef())
14177     return DAG.getUNDEF(VT);
14178 
14179   // If the input is a BUILD_VECTOR with all constant elements, fold this now.
14180   // Only do this before legalize types, unless both types are integer and the
14181   // scalar type is legal. Only do this before legalize ops, since the target
14182   // may be depending on the bitcast.
14183   // First check to see if this is all constant.
14184   // TODO: Support FP bitcasts after legalize types.
14185   if (VT.isVector() &&
14186       (!LegalTypes ||
14187        (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
14188         TLI.isTypeLegal(VT.getVectorElementType()))) &&
14189       N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
14190       cast<BuildVectorSDNode>(N0)->isConstant())
14191     return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
14192                                              VT.getVectorElementType());
14193 
14194   // If the input is a constant, let getNode fold it.
14195   if (isIntOrFPConstant(N0)) {
14196     // If we can't allow illegal operations, we need to check that this is just
14197     // an fp -> int or int -> fp conversion and that the resulting operation will
14198     // be legal.
14199     if (!LegalOperations ||
14200         (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
14201          TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
14202         (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
14203          TLI.isOperationLegal(ISD::Constant, VT))) {
14204       SDValue C = DAG.getBitcast(VT, N0);
14205       if (C.getNode() != N)
14206         return C;
14207     }
14208   }
14209 
14210   // (conv (conv x, t1), t2) -> (conv x, t2)
14211   if (N0.getOpcode() == ISD::BITCAST)
14212     return DAG.getBitcast(VT, N0.getOperand(0));
14213 
14214   // fold (conv (load x)) -> (load (conv*)x)
14215   // If the resultant load doesn't need a higher alignment than the original!
14216   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
14217       // Do not remove the cast if the types differ in endian layout.
14218       TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
14219           TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
14220       // If the load is volatile, we only want to change the load type if the
14221       // resulting load is legal. Otherwise we might increase the number of
14222       // memory accesses. We don't care if the original type was legal or not
14223       // as we assume software couldn't rely on the number of accesses of an
14224       // illegal type.
14225       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
14226        TLI.isOperationLegal(ISD::LOAD, VT))) {
14227     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14228 
14229     if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
14230                                     *LN0->getMemOperand())) {
14231       SDValue Load =
14232           DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
14233                       LN0->getPointerInfo(), LN0->getAlign(),
14234                       LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14235       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
14236       return Load;
14237     }
14238   }
14239 
14240   if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
14241     return V;
14242 
14243   // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
14244   // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
14245   //
14246   // For ppc_fp128:
14247   // fold (bitcast (fneg x)) ->
14248   //     flipbit = signbit
14249   //     (xor (bitcast x) (build_pair flipbit, flipbit))
14250   //
14251   // fold (bitcast (fabs x)) ->
14252   //     flipbit = (and (extract_element (bitcast x), 0), signbit)
14253   //     (xor (bitcast x) (build_pair flipbit, flipbit))
14254   // This often reduces constant pool loads.
14255   if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
14256        (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
14257       N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
14258       !N0.getValueType().isVector()) {
14259     SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
14260     AddToWorklist(NewConv.getNode());
14261 
14262     SDLoc DL(N);
14263     if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
14264       assert(VT.getSizeInBits() == 128);
14265       SDValue SignBit = DAG.getConstant(
14266           APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
14267       SDValue FlipBit;
14268       if (N0.getOpcode() == ISD::FNEG) {
14269         FlipBit = SignBit;
14270         AddToWorklist(FlipBit.getNode());
14271       } else {
14272         assert(N0.getOpcode() == ISD::FABS);
14273         SDValue Hi =
14274             DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
14275                         DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
14276                                               SDLoc(NewConv)));
14277         AddToWorklist(Hi.getNode());
14278         FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
14279         AddToWorklist(FlipBit.getNode());
14280       }
14281       SDValue FlipBits =
14282           DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
14283       AddToWorklist(FlipBits.getNode());
14284       return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
14285     }
14286     APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
14287     if (N0.getOpcode() == ISD::FNEG)
14288       return DAG.getNode(ISD::XOR, DL, VT,
14289                          NewConv, DAG.getConstant(SignBit, DL, VT));
14290     assert(N0.getOpcode() == ISD::FABS);
14291     return DAG.getNode(ISD::AND, DL, VT,
14292                        NewConv, DAG.getConstant(~SignBit, DL, VT));
14293   }
14294 
14295   // fold (bitconvert (fcopysign cst, x)) ->
14296   //         (or (and (bitconvert x), sign), (and cst, (not sign)))
14297   // Note that we don't handle (copysign x, cst) because this can always be
14298   // folded to an fneg or fabs.
14299   //
14300   // For ppc_fp128:
14301   // fold (bitcast (fcopysign cst, x)) ->
14302   //     flipbit = (and (extract_element
14303   //                     (xor (bitcast cst), (bitcast x)), 0),
14304   //                    signbit)
14305   //     (xor (bitcast cst) (build_pair flipbit, flipbit))
14306   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
14307       isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
14308       !VT.isVector()) {
14309     unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
14310     EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
14311     if (isTypeLegal(IntXVT)) {
14312       SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
14313       AddToWorklist(X.getNode());
14314 
14315       // If X has a different width than the result/lhs, sext it or truncate it.
14316       unsigned VTWidth = VT.getSizeInBits();
14317       if (OrigXWidth < VTWidth) {
14318         X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
14319         AddToWorklist(X.getNode());
14320       } else if (OrigXWidth > VTWidth) {
14321         // To get the sign bit in the right place, we have to shift it right
14322         // before truncating.
14323         SDLoc DL(X);
14324         X = DAG.getNode(ISD::SRL, DL,
14325                         X.getValueType(), X,
14326                         DAG.getConstant(OrigXWidth-VTWidth, DL,
14327                                         X.getValueType()));
14328         AddToWorklist(X.getNode());
14329         X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
14330         AddToWorklist(X.getNode());
14331       }
14332 
14333       if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
14334         APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
14335         SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
14336         AddToWorklist(Cst.getNode());
14337         SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
14338         AddToWorklist(X.getNode());
14339         SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
14340         AddToWorklist(XorResult.getNode());
14341         SDValue XorResult64 = DAG.getNode(
14342             ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
14343             DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
14344                                   SDLoc(XorResult)));
14345         AddToWorklist(XorResult64.getNode());
14346         SDValue FlipBit =
14347             DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
14348                         DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
14349         AddToWorklist(FlipBit.getNode());
14350         SDValue FlipBits =
14351             DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
14352         AddToWorklist(FlipBits.getNode());
14353         return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
14354       }
14355       APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
14356       X = DAG.getNode(ISD::AND, SDLoc(X), VT,
14357                       X, DAG.getConstant(SignBit, SDLoc(X), VT));
14358       AddToWorklist(X.getNode());
14359 
14360       SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
14361       Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
14362                         Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
14363       AddToWorklist(Cst.getNode());
14364 
14365       return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
14366     }
14367   }
14368 
14369   // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
14370   if (N0.getOpcode() == ISD::BUILD_PAIR)
14371     if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
14372       return CombineLD;
14373 
14374   // Remove double bitcasts from shuffles - this is often a legacy of
14375   // XformToShuffleWithZero being used to combine bitmaskings (of
14376   // float vectors bitcast to integer vectors) into shuffles.
14377   // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
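        // e.g. with a mask scale of 2:
        //   (v4i32 bitcast (v2i64 shuffle (bitcast v4i32:s0),
        //                                 (bitcast v4i32:s1), <1, 0>))
        //     -> (v4i32 shuffle s0, s1, <2, 3, 0, 1>)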
14378   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
14379       N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
14380       VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
14381       !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
14382     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
14383 
14384     // If operands are a bitcast, peek through if it casts the original VT.
14385     // If operands are a constant, just bitcast back to original VT.
14386     auto PeekThroughBitcast = [&](SDValue Op) {
14387       if (Op.getOpcode() == ISD::BITCAST &&
14388           Op.getOperand(0).getValueType() == VT)
14389         return SDValue(Op.getOperand(0));
14390       if (Op.isUndef() || isAnyConstantBuildVector(Op))
14391         return DAG.getBitcast(VT, Op);
14392       return SDValue();
14393     };
14394 
14395     // FIXME: If either input vector is bitcast, try to convert the shuffle to
14396     // the result type of this bitcast. This would eliminate at least one
14397     // bitcast. See the transform in InstCombine.
14398     SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
14399     SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
14400     if (!(SV0 && SV1))
14401       return SDValue();
14402 
14403     int MaskScale =
14404         VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
14405     SmallVector<int, 8> NewMask;
14406     for (int M : SVN->getMask())
14407       for (int i = 0; i != MaskScale; ++i)
14408         NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
14409 
14410     SDValue LegalShuffle =
14411         TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
14412     if (LegalShuffle)
14413       return LegalShuffle;
14414   }
14415 
14416   return SDValue();
14417 }
14418 
14419 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
14420   EVT VT = N->getValueType(0);
14421   return CombineConsecutiveLoads(N, VT);
14422 }
14423 
14424 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
14425   SDValue N0 = N->getOperand(0);
14426 
14427   if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
14428     return N0;
14429 
14430   // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
14431   // Try to push freeze through instructions that propagate but don't produce
14432   // poison as far as possible. If the operand of the freeze meets three
14433   // conditions: 1) one use, 2) does not produce poison, and 3) all but one
14434   // operand guaranteed non-poison (or it is a BUILD_VECTOR or similar), then
14435   // push the freeze through to the operands that are not guaranteed non-poison.
14436   // NOTE: we will strip poison-generating flags, so ignore them here.
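        // e.g. (freeze (add nsw x, C)) -> (add (freeze x), C), with the nsw flag
        // dropped, since the constant C is trivially not poison.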
14437   if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
14438                                  /*ConsiderFlags*/ false) ||
14439       N0->getNumValues() != 1 || !N0->hasOneUse())
14440     return SDValue();
14441 
14442   bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR;
14443 
14444   SmallSetVector<SDValue, 8> MaybePoisonOperands;
14445   for (SDValue Op : N0->ops()) {
14446     if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
14447                                              /*Depth*/ 1))
14448       continue;
14449     bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
14450     bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op);
14451     if (!HadMaybePoisonOperands)
14452       continue;
14453     if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
14454       // Multiple maybe-poison ops when not allowed - bail out.
14455       return SDValue();
14456     }
14457   }
14458   // NOTE: the whole op may still not be guaranteed not to be undef or poison
14459   // because it could create undef or poison due to its poison-generating flags.
14460   // So not finding any maybe-poison operands is fine.
14461 
14462   for (SDValue MaybePoisonOperand : MaybePoisonOperands) {
14463     // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
14464     if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
14465       continue;
14466     // First, freeze each offending operand.
14467     SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
14468     // Then, change all other uses of unfrozen operand to use frozen operand.
14469     DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
14470     if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
14471         FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
14472       // But, that also updated the use in the freeze we just created, thus
14473       // creating a cycle in a DAG. Let's undo that by mutating the freeze.
14474       DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
14475                              MaybePoisonOperand);
14476     }
14477   }
14478 
14479   // The whole node may have been updated, so the value we were holding
14480   // may no longer be valid. Re-fetch the operand we're `freeze`ing.
14481   N0 = N->getOperand(0);
14482 
14483   // Finally, recreate the node; its operands were updated to use
14484   // frozen operands, so we just need to use its "original" operands.
14485   SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
14486   // Special-handle ISD::UNDEF: each single one of them can be its own thing.
14487   for (SDValue &Op : Ops) {
14488     if (Op.getOpcode() == ISD::UNDEF)
14489       Op = DAG.getFreeze(Op);
14490   }
14491   // NOTE: this strips poison generating flags.
14492   SDValue R = DAG.getNode(N0.getOpcode(), SDLoc(N0), N0->getVTList(), Ops);
14493   assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
14494          "Can't create node that may be undef/poison!");
14495   return R;
14496 }
14497 
14498 /// We know that BV is a build_vector node with Constant, ConstantFP or Undef
14499 /// operands. DstEltVT indicates the destination element value type.
14500 SDValue DAGCombiner::
14501 ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
14502   EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
14503 
14504   // If this is already the right type, we're done.
14505   if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
14506 
14507   unsigned SrcBitSize = SrcEltVT.getSizeInBits();
14508   unsigned DstBitSize = DstEltVT.getSizeInBits();
14509 
14510   // If this is a conversion of N elements of one type to N elements of another
14511   // type, convert each element.  This handles FP<->INT cases.
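        // e.g. (v2f32 bitcast (v2i32 BUILD_VECTOR C0, C1))
        //        -> (v2f32 BUILD_VECTOR (f32 bitcast C0), (f32 bitcast C1))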
14512   if (SrcBitSize == DstBitSize) {
14513     SmallVector<SDValue, 8> Ops;
14514     for (SDValue Op : BV->op_values()) {
14515       // If the vector element type is not legal, the BUILD_VECTOR operands
14516       // are promoted and implicitly truncated.  Make that explicit here.
14517       if (Op.getValueType() != SrcEltVT)
14518         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
14519       Ops.push_back(DAG.getBitcast(DstEltVT, Op));
14520       AddToWorklist(Ops.back().getNode());
14521     }
14522     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
14523                               BV->getValueType(0).getVectorNumElements());
14524     return DAG.getBuildVector(VT, SDLoc(BV), Ops);
14525   }
14526 
14527   // Otherwise, we're growing or shrinking the elements.  To avoid having to
14528   // handle annoying details of growing/shrinking FP values, we convert them to
14529   // int first.
14530   if (SrcEltVT.isFloatingPoint()) {
14531     // Convert the input float vector to an int vector where the elements are
14532     // the same size.
14533     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
14534     BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
14535     SrcEltVT = IntVT;
14536   }
14537 
14538   // Now we know the input is an integer vector.  If the output is a FP type,
14539   // convert to integer first, then to FP of the right size.
14540   if (DstEltVT.isFloatingPoint()) {
14541     EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
14542     SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
14543 
14544     // Next, convert to FP elements of the same size.
14545     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
14546   }
14547 
14548   // Okay, we know the src/dst types are both integers of differing types.
14549   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
14550 
14551   // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
14552   // BuildVectorSDNode?
14553   auto *BVN = cast<BuildVectorSDNode>(BV);
14554 
14555   // Extract the constant raw bit data.
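        // e.g. on a little-endian target,
        //   (v1i64 bitcast (v2i32 BUILD_VECTOR 1, 2))
        //     -> (v1i64 BUILD_VECTOR 0x0000000200000001)
        // since element 0 supplies the least significant 32 bits.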
14556   BitVector UndefElements;
14557   SmallVector<APInt> RawBits;
14558   bool IsLE = DAG.getDataLayout().isLittleEndian();
14559   if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
14560     return SDValue();
14561 
14562   SDLoc DL(BV);
14563   SmallVector<SDValue, 8> Ops;
14564   for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
14565     if (UndefElements[I])
14566       Ops.push_back(DAG.getUNDEF(DstEltVT));
14567     else
14568       Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
14569   }
14570 
14571   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
14572   return DAG.getBuildVector(VT, DL, Ops);
14573 }
14574 
14575 // Returns true if floating-point contraction is allowed on the FMUL-SDValue
14576 // `N`.
14577 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
14578   assert(N.getOpcode() == ISD::FMUL);
14579 
14580   return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
14581          N->getFlags().hasAllowContract();
14582 }
14583 
14584 // Returns true if `N` can assume no infinities involved in its computation.
14585 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
14586   return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
14587 }
14588 
14589 /// Try to perform FMA combining on a given FADD node.
14590 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
14591   SDValue N0 = N->getOperand(0);
14592   SDValue N1 = N->getOperand(1);
14593   EVT VT = N->getValueType(0);
14594   SDLoc SL(N);
14595 
14596   const TargetOptions &Options = DAG.getTarget().Options;
14597 
14598   // Floating-point multiply-add with intermediate rounding.
14599   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
14600 
14601   // Floating-point multiply-add without intermediate rounding.
14602   bool HasFMA =
14603       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
14604       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
14605 
14606   // No valid opcode, do not combine.
14607   if (!HasFMAD && !HasFMA)
14608     return SDValue();
14609 
14610   bool CanReassociate =
14611       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
14612   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
14613                               Options.UnsafeFPMath || HasFMAD);
14614   // If the addition is not contractable, do not combine.
14615   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
14616     return SDValue();
14617 
14618   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
14619     return SDValue();
14620 
14621   // Always prefer FMAD to FMA for precision.
14622   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
14623   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
14624 
14625   auto isFusedOp = [&](SDValue N) {
14626     unsigned Opcode = N.getOpcode();
14627     return Opcode == ISD::FMA || Opcode == ISD::FMAD;
14628   };
14629 
14630   // Is the node an FMUL and contractable either due to global flags or
14631   // SDNodeFlags.
14632   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
14633     if (N.getOpcode() != ISD::FMUL)
14634       return false;
14635     return AllowFusionGlobally || N->getFlags().hasAllowContract();
14636   };
14637   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
14638   // prefer to fold the multiply with fewer uses.
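        // (Folding the FMUL with fewer uses makes it more likely that the FMUL
        // node itself becomes dead and can be deleted.)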
14639   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
14640     if (N0->use_size() > N1->use_size())
14641       std::swap(N0, N1);
14642   }
14643 
14644   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
14645   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
14646     return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
14647                        N0.getOperand(1), N1);
14648   }
14649 
14650   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
14651   // Note: Commutes FADD operands.
14652   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
14653     return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
14654                        N1.getOperand(1), N0);
14655   }
14656 
14657   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
14658   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
14659   // This also works with nested fma instructions:
14660   // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
14661   //   fma A, B, (fma C, D, (fma E, F, G))
14662   // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
14663   //   fma A, B, (fma C, D, (fma E, F, G)).
14664   // This requires reassociation because it changes the order of operations.
14665   if (CanReassociate) {
14666     SDValue FMA, E;
14667     if (isFusedOp(N0) && N0.hasOneUse()) {
14668       FMA = N0;
14669       E = N1;
14670     } else if (isFusedOp(N1) && N1.hasOneUse()) {
14671       FMA = N1;
14672       E = N0;
14673     }
14674 
14675     SDValue TmpFMA = FMA;
14676     while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
14677       SDValue FMul = TmpFMA->getOperand(2);
14678       if (FMul.getOpcode() == ISD::FMUL && FMul.hasOneUse()) {
14679         SDValue C = FMul.getOperand(0);
14680         SDValue D = FMul.getOperand(1);
14681         SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
14682         DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
14683         // Replacing the inner FMul could cause the outer FMA to be simplified
14684         // away.
14685         return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue() : FMA;
14686       }
14687 
14688       TmpFMA = TmpFMA->getOperand(2);
14689     }
14690   }
14691 
14692   // Look through FP_EXTEND nodes to do more combining.
14693 
14694   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
14695   if (N0.getOpcode() == ISD::FP_EXTEND) {
14696     SDValue N00 = N0.getOperand(0);
14697     if (isContractableFMUL(N00) &&
14698         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14699                             N00.getValueType())) {
14700       return DAG.getNode(PreferredFusedOpcode, SL, VT,
14701                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14702                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14703                          N1);
14704     }
14705   }
14706 
14707   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
14708   // Note: Commutes FADD operands.
14709   if (N1.getOpcode() == ISD::FP_EXTEND) {
14710     SDValue N10 = N1.getOperand(0);
14711     if (isContractableFMUL(N10) &&
14712         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14713                             N10.getValueType())) {
14714       return DAG.getNode(PreferredFusedOpcode, SL, VT,
14715                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
14716                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
14717                          N0);
14718     }
14719   }
14720 
14721   // More folding opportunities when target permits.
14722   if (Aggressive) {
14723     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
14724     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
14725     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
14726                                     SDValue Z) {
14727       return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
14728                          DAG.getNode(PreferredFusedOpcode, SL, VT,
14729                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
14730                                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
14731                                      Z));
14732     };
14733     if (isFusedOp(N0)) {
14734       SDValue N02 = N0.getOperand(2);
14735       if (N02.getOpcode() == ISD::FP_EXTEND) {
14736         SDValue N020 = N02.getOperand(0);
14737         if (isContractableFMUL(N020) &&
14738             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14739                                 N020.getValueType())) {
14740           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
14741                                       N020.getOperand(0), N020.getOperand(1),
14742                                       N1);
14743         }
14744       }
14745     }
14746 
14747     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
14748     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
14749     // FIXME: This turns two single-precision and one double-precision
14750     // operation into two double-precision operations, which might not be
14751     // interesting for all targets, especially GPUs.
14752     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
14753                                     SDValue Z) {
14754       return DAG.getNode(
14755           PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
14756           DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
14757           DAG.getNode(PreferredFusedOpcode, SL, VT,
14758                       DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
14759                       DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
14760     };
14761     if (N0.getOpcode() == ISD::FP_EXTEND) {
14762       SDValue N00 = N0.getOperand(0);
14763       if (isFusedOp(N00)) {
14764         SDValue N002 = N00.getOperand(2);
14765         if (isContractableFMUL(N002) &&
14766             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14767                                 N00.getValueType())) {
14768           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
14769                                       N002.getOperand(0), N002.getOperand(1),
14770                                       N1);
14771         }
14772       }
14773     }
14774 
14775     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
14776     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
14777     if (isFusedOp(N1)) {
14778       SDValue N12 = N1.getOperand(2);
14779       if (N12.getOpcode() == ISD::FP_EXTEND) {
14780         SDValue N120 = N12.getOperand(0);
14781         if (isContractableFMUL(N120) &&
14782             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14783                                 N120.getValueType())) {
14784           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
14785                                       N120.getOperand(0), N120.getOperand(1),
14786                                       N0);
14787         }
14788       }
14789     }
14790 
14791     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
14792     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
14793     // FIXME: This turns two single-precision and one double-precision
14794     // operation into two double-precision operations, which might not be
14795     // interesting for all targets, especially GPUs.
14796     if (N1.getOpcode() == ISD::FP_EXTEND) {
14797       SDValue N10 = N1.getOperand(0);
14798       if (isFusedOp(N10)) {
14799         SDValue N102 = N10.getOperand(2);
14800         if (isContractableFMUL(N102) &&
14801             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14802                                 N10.getValueType())) {
14803           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
14804                                       N102.getOperand(0), N102.getOperand(1),
14805                                       N0);
14806         }
14807       }
14808     }
14809   }
14810 
14811   return SDValue();
14812 }
14813 
14814 /// Try to perform FMA combining on a given FSUB node.
14815 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
14816   SDValue N0 = N->getOperand(0);
14817   SDValue N1 = N->getOperand(1);
14818   EVT VT = N->getValueType(0);
14819   SDLoc SL(N);
14820 
14821   const TargetOptions &Options = DAG.getTarget().Options;
14822   // Floating-point multiply-add with intermediate rounding.
14823   bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
14824 
14825   // Floating-point multiply-add without intermediate rounding.
14826   bool HasFMA =
14827       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
14828       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
14829 
14830   // No valid opcode, do not combine.
14831   if (!HasFMAD && !HasFMA)
14832     return SDValue();
14833 
14834   const SDNodeFlags Flags = N->getFlags();
14835   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
14836                               Options.UnsafeFPMath || HasFMAD);
14837 
14838   // If the subtraction is not contractable, do not combine.
14839   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
14840     return SDValue();
14841 
14842   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
14843     return SDValue();
14844 
14845   // Always prefer FMAD to FMA for precision.
14846   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
14847   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
14848   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
14849 
14850   // Is the node an FMUL and contractable either due to global flags or
14851   // SDNodeFlags.
14852   auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
14853     if (N.getOpcode() != ISD::FMUL)
14854       return false;
14855     return AllowFusionGlobally || N->getFlags().hasAllowContract();
14856   };
14857 
14858   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14859   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
14860     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
14861       return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
14862                          XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
14863     }
14864     return SDValue();
14865   };
14866 
14867   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14868   // Note: Commutes FSUB operands.
14869   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
14870     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
14871       return DAG.getNode(PreferredFusedOpcode, SL, VT,
14872                          DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
14873                          YZ.getOperand(1), X);
14874     }
14875     return SDValue();
14876   };
14877 
14878   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
14879   // prefer to fold the multiply with fewer uses.
14880   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
14881       (N0->use_size() > N1->use_size())) {
14882     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
14883     if (SDValue V = tryToFoldXSubYZ(N0, N1))
14884       return V;
14885     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
14886     if (SDValue V = tryToFoldXYSubZ(N0, N1))
14887       return V;
14888   } else {
14889     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14890     if (SDValue V = tryToFoldXYSubZ(N0, N1))
14891       return V;
14892     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14893     if (SDValue V = tryToFoldXSubYZ(N0, N1))
14894       return V;
14895   }
14896 
14897   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
14898   if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
14899       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
14900     SDValue N00 = N0.getOperand(0).getOperand(0);
14901     SDValue N01 = N0.getOperand(0).getOperand(1);
14902     return DAG.getNode(PreferredFusedOpcode, SL, VT,
14903                        DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
14904                        DAG.getNode(ISD::FNEG, SL, VT, N1));
14905   }
14906 
14907   // Look through FP_EXTEND nodes to do more combining.
14908 
14909   // fold (fsub (fpext (fmul x, y)), z)
14910   //   -> (fma (fpext x), (fpext y), (fneg z))
14911   if (N0.getOpcode() == ISD::FP_EXTEND) {
14912     SDValue N00 = N0.getOperand(0);
14913     if (isContractableFMUL(N00) &&
14914         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14915                             N00.getValueType())) {
14916       return DAG.getNode(PreferredFusedOpcode, SL, VT,
14917                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14918                          DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14919                          DAG.getNode(ISD::FNEG, SL, VT, N1));
14920     }
14921   }
14922 
14923   // fold (fsub x, (fpext (fmul y, z)))
14924   //   -> (fma (fneg (fpext y)), (fpext z), x)
14925   // Note: Commutes FSUB operands.
14926   if (N1.getOpcode() == ISD::FP_EXTEND) {
14927     SDValue N10 = N1.getOperand(0);
14928     if (isContractableFMUL(N10) &&
14929         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14930                             N10.getValueType())) {
14931       return DAG.getNode(
14932           PreferredFusedOpcode, SL, VT,
14933           DAG.getNode(ISD::FNEG, SL, VT,
14934                       DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
14935           DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
14936     }
14937   }
14938 
14939   // fold (fsub (fpext (fneg (fmul x, y))), z)
14940   //   -> (fneg (fma (fpext x), (fpext y), z))
14941   // Note: This could be removed with appropriate canonicalization of the
14942   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
14943   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
14944   // from implementing the canonicalization in visitFSUB.
14945   if (N0.getOpcode() == ISD::FP_EXTEND) {
14946     SDValue N00 = N0.getOperand(0);
14947     if (N00.getOpcode() == ISD::FNEG) {
14948       SDValue N000 = N00.getOperand(0);
14949       if (isContractableFMUL(N000) &&
14950           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14951                               N00.getValueType())) {
14952         return DAG.getNode(
14953             ISD::FNEG, SL, VT,
14954             DAG.getNode(PreferredFusedOpcode, SL, VT,
14955                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14956                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14957                         N1));
14958       }
14959     }
14960   }
14961 
14962   // fold (fsub (fneg (fpext (fmul x, y))), z)
14963   //   -> (fneg (fma (fpext x), (fpext y), z))
14964   // Note: This could be removed with appropriate canonicalization of the
14965   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
14966   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
14967   // from implementing the canonicalization in visitFSUB.
14968   if (N0.getOpcode() == ISD::FNEG) {
14969     SDValue N00 = N0.getOperand(0);
14970     if (N00.getOpcode() == ISD::FP_EXTEND) {
14971       SDValue N000 = N00.getOperand(0);
14972       if (isContractableFMUL(N000) &&
14973           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14974                               N000.getValueType())) {
14975         return DAG.getNode(
14976             ISD::FNEG, SL, VT,
14977             DAG.getNode(PreferredFusedOpcode, SL, VT,
14978                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14979                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14980                         N1));
14981       }
14982     }
14983   }
14984 
14985   auto isReassociable = [Options](SDNode *N) {
14986     return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
14987   };
14988 
14989   auto isContractableAndReassociableFMUL = [&isContractableFMUL,
14990                                             &isReassociable](SDValue N) {
14991     return isContractableFMUL(N) && isReassociable(N.getNode());
14992   };
14993 
14994   auto isFusedOp = [&](SDValue N) {
14995     unsigned Opcode = N.getOpcode();
14996     return Opcode == ISD::FMA || Opcode == ISD::FMAD;
14997   };
14998 
14999   // More folding opportunities when target permits.
15000   if (Aggressive && isReassociable(N)) {
15001     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
15002     // fold (fsub (fma x, y, (fmul u, v)), z)
15003     //   -> (fma x, y, (fma u, v, (fneg z)))
15004     if (CanFuse && isFusedOp(N0) &&
15005         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
15006         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
15007       return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
15008                          N0.getOperand(1),
15009                          DAG.getNode(PreferredFusedOpcode, SL, VT,
15010                                      N0.getOperand(2).getOperand(0),
15011                                      N0.getOperand(2).getOperand(1),
15012                                      DAG.getNode(ISD::FNEG, SL, VT, N1)));
15013     }
15014 
15015     // fold (fsub x, (fma y, z, (fmul u, v)))
15016     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
15017     if (CanFuse && isFusedOp(N1) &&
15018         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
15019         N1->hasOneUse() && NoSignedZero) {
15020       SDValue N20 = N1.getOperand(2).getOperand(0);
15021       SDValue N21 = N1.getOperand(2).getOperand(1);
15022       return DAG.getNode(
15023           PreferredFusedOpcode, SL, VT,
15024           DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
15025           DAG.getNode(PreferredFusedOpcode, SL, VT,
15026                       DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
15027     }
15028 
15029     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
15030     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
15031     if (isFusedOp(N0) && N0->hasOneUse()) {
15032       SDValue N02 = N0.getOperand(2);
15033       if (N02.getOpcode() == ISD::FP_EXTEND) {
15034         SDValue N020 = N02.getOperand(0);
15035         if (isContractableAndReassociableFMUL(N020) &&
15036             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15037                                 N020.getValueType())) {
15038           return DAG.getNode(
15039               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
15040               DAG.getNode(
15041                   PreferredFusedOpcode, SL, VT,
15042                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
15043                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
15044                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
15045         }
15046       }
15047     }
15048 
15049     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
15050     //   -> (fma (fpext x), (fpext y),
15051     //           (fma (fpext u), (fpext v), (fneg z)))
15052     // FIXME: This turns two single-precision and one double-precision
15053     // operation into two double-precision operations, which might not be
15054     // interesting for all targets, especially GPUs.
15055     if (N0.getOpcode() == ISD::FP_EXTEND) {
15056       SDValue N00 = N0.getOperand(0);
15057       if (isFusedOp(N00)) {
15058         SDValue N002 = N00.getOperand(2);
15059         if (isContractableAndReassociableFMUL(N002) &&
15060             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15061                                 N00.getValueType())) {
15062           return DAG.getNode(
15063               PreferredFusedOpcode, SL, VT,
15064               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
15065               DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
15066               DAG.getNode(
15067                   PreferredFusedOpcode, SL, VT,
15068                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
15069                   DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
15070                   DAG.getNode(ISD::FNEG, SL, VT, N1)));
15071         }
15072       }
15073     }
15074 
15075     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
15076     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
15077     if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
15078         N1->hasOneUse()) {
15079       SDValue N120 = N1.getOperand(2).getOperand(0);
15080       if (isContractableAndReassociableFMUL(N120) &&
15081           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15082                               N120.getValueType())) {
15083         SDValue N1200 = N120.getOperand(0);
15084         SDValue N1201 = N120.getOperand(1);
15085         return DAG.getNode(
15086             PreferredFusedOpcode, SL, VT,
15087             DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
15088             DAG.getNode(PreferredFusedOpcode, SL, VT,
15089                         DAG.getNode(ISD::FNEG, SL, VT,
15090                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
15091                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
15092       }
15093     }
15094 
15095     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
15096     //   -> (fma (fneg (fpext y)), (fpext z),
15097     //           (fma (fneg (fpext u)), (fpext v), x))
15098     // FIXME: This turns two single-precision and one double-precision
15099     // operation into two double-precision operations, which might not be
15100     // interesting for all targets, especially GPUs.
15101     if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
15102       SDValue CvtSrc = N1.getOperand(0);
15103       SDValue N100 = CvtSrc.getOperand(0);
15104       SDValue N101 = CvtSrc.getOperand(1);
15105       SDValue N102 = CvtSrc.getOperand(2);
15106       if (isContractableAndReassociableFMUL(N102) &&
15107           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15108                               CvtSrc.getValueType())) {
15109         SDValue N1020 = N102.getOperand(0);
15110         SDValue N1021 = N102.getOperand(1);
15111         return DAG.getNode(
15112             PreferredFusedOpcode, SL, VT,
15113             DAG.getNode(ISD::FNEG, SL, VT,
15114                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
15115             DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
15116             DAG.getNode(PreferredFusedOpcode, SL, VT,
15117                         DAG.getNode(ISD::FNEG, SL, VT,
15118                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
15119                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
15120       }
15121     }
15122   }
15123 
15124   return SDValue();
15125 }
15126 
15127 /// Try to perform FMA combining on a given FMUL node based on the distributive
15128 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
15129 /// subtraction instead of addition).
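      /// For example, (fmul (fadd x, 1.0), y) becomes (fma x, y, y) and
      /// (fmul (fsub x, 1.0), y) becomes (fma x, y, (fneg y)).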
15130 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
15131   SDValue N0 = N->getOperand(0);
15132   SDValue N1 = N->getOperand(1);
15133   EVT VT = N->getValueType(0);
15134   SDLoc SL(N);
15135 
15136   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
15137 
15138   const TargetOptions &Options = DAG.getTarget().Options;
15139 
15140   // The transforms below are incorrect when x == 0 and y == inf, because the
15141   // intermediate multiplication produces a nan.
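        // For instance, in (fmul (fadd x, 1.0), y) -> (fma x, y, y) with
        // x == 0.0 and y == +Inf: the original value is (0.0 + 1.0) * +Inf
        // == +Inf, but the fused form's intermediate product is
        // 0.0 * +Inf == NaN, so fma(0.0, +Inf, +Inf) == NaN.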
15142   SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
15143   if (!hasNoInfs(Options, FAdd))
15144     return SDValue();
15145 
15146   // Floating-point multiply-add without intermediate rounding.
15147   bool HasFMA =
15148       isContractableFMUL(Options, SDValue(N, 0)) &&
15149       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
15150       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
15151 
15152   // Floating-point multiply-add with intermediate rounding. This can result
15153   // in a less precise result due to the changed rounding order.
15154   bool HasFMAD = Options.UnsafeFPMath &&
15155                  (LegalOperations && TLI.isFMADLegal(DAG, N));
15156 
15157   // No valid opcode, do not combine.
15158   if (!HasFMAD && !HasFMA)
15159     return SDValue();
15160 
15161   // Always prefer FMAD to FMA for precision.
15162   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15163   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15164 
15165   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
15166   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
15167   auto FuseFADD = [&](SDValue X, SDValue Y) {
15168     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
15169       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
15170         if (C->isExactlyValue(+1.0))
15171           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15172                              Y);
15173         if (C->isExactlyValue(-1.0))
15174           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15175                              DAG.getNode(ISD::FNEG, SL, VT, Y));
15176       }
15177     }
15178     return SDValue();
15179   };
15180 
15181   if (SDValue FMA = FuseFADD(N0, N1))
15182     return FMA;
15183   if (SDValue FMA = FuseFADD(N1, N0))
15184     return FMA;
15185 
15186   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
15187   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
15188   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
15189   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
15190   auto FuseFSUB = [&](SDValue X, SDValue Y) {
15191     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
15192       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
15193         if (C0->isExactlyValue(+1.0))
15194           return DAG.getNode(PreferredFusedOpcode, SL, VT,
15195                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
15196                              Y);
15197         if (C0->isExactlyValue(-1.0))
15198           return DAG.getNode(PreferredFusedOpcode, SL, VT,
15199                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
15200                              DAG.getNode(ISD::FNEG, SL, VT, Y));
15201       }
15202       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
15203         if (C1->isExactlyValue(+1.0))
15204           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15205                              DAG.getNode(ISD::FNEG, SL, VT, Y));
15206         if (C1->isExactlyValue(-1.0))
15207           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15208                              Y);
15209       }
15210     }
15211     return SDValue();
15212   };
15213 
15214   if (SDValue FMA = FuseFSUB(N0, N1))
15215     return FMA;
15216   if (SDValue FMA = FuseFSUB(N1, N0))
15217     return FMA;
15218 
15219   return SDValue();
15220 }
15221 
15222 SDValue DAGCombiner::visitFADD(SDNode *N) {
15223   SDValue N0 = N->getOperand(0);
15224   SDValue N1 = N->getOperand(1);
15225   SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
15226   SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
15227   EVT VT = N->getValueType(0);
15228   SDLoc DL(N);
15229   const TargetOptions &Options = DAG.getTarget().Options;
15230   SDNodeFlags Flags = N->getFlags();
15231   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15232 
15233   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15234     return R;
15235 
15236   // fold (fadd c1, c2) -> c1 + c2
15237   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
15238     return C;
15239 
15240   // canonicalize constant to RHS
15241   if (N0CFP && !N1CFP)
15242     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
15243 
15244   // fold vector ops
15245   if (VT.isVector())
15246     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15247       return FoldedVOp;
15248 
15249   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
15250   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
15251   if (N1C && N1C->isZero())
15252     if (N1C->isNegative() || Options.NoSignedZerosFPMath ||
              Flags.hasNoSignedZeros())
15253       return N0;
15254 
15255   if (SDValue NewSel = foldBinOpIntoSelect(N))
15256     return NewSel;
15257 
15258   // fold (fadd A, (fneg B)) -> (fsub A, B)
15259   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
15260     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
15261             N1, DAG, LegalOperations, ForCodeSize))
15262       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
15263 
15264   // fold (fadd (fneg A), B) -> (fsub B, A)
15265   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
15266     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
15267             N0, DAG, LegalOperations, ForCodeSize))
15268       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
15269 
15270   auto isFMulNegTwo = [](SDValue FMul) {
15271     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
15272       return false;
15273     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
15274     return C && C->isExactlyValue(-2.0);
15275   };
15276 
15277   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
15278   if (isFMulNegTwo(N0)) {
15279     SDValue B = N0.getOperand(0);
15280     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
15281     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
15282   }
15283   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
15284   if (isFMulNegTwo(N1)) {
15285     SDValue B = N1.getOperand(0);
15286     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
15287     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
15288   }
15289 
15290   // No FP constant should be created after legalization, as the instruction
15291   // selection pass has a hard time dealing with FP constants.
15292   bool AllowNewConst = (Level < AfterLegalizeDAG);
15293 
15294   // If nnan is enabled, fold lots of things.
15295   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
15296     // If allowed, fold (fadd (fneg x), x) -> 0.0
15297     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
15298       return DAG.getConstantFP(0.0, DL, VT);
15299 
15300     // If allowed, fold (fadd x, (fneg x)) -> 0.0
15301     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
15302       return DAG.getConstantFP(0.0, DL, VT);
15303   }
15304 
15305   // If 'unsafe math' or reassoc and nsz, fold lots of things.
15306   // TODO: break out the portions of the transformations below that are
15307   //       justified by Unsafe alone and do not require both nsz and reassoc.
15308   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
15309        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
15310       AllowNewConst) {
15311     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
15312     if (N1CFP && N0.getOpcode() == ISD::FADD &&
15313         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
15314       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
15315       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
15316     }
15317 
15318     // We can fold chains of FADD's of the same value into multiplications.
15319     // This transform is not safe in general because we are reducing the number
15320     // of rounding steps.
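          // For instance, (fadd (fmul x, c), x) rounds after the multiply and
          // again after the add, while the folded (fmul x, c+1) rounds only
          // once, so the two can differ in the last bit.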
15321     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
15322       if (N0.getOpcode() == ISD::FMUL) {
15323         SDNode *CFP00 =
15324             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
15325         SDNode *CFP01 =
15326             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
15327 
15328         // (fadd (fmul x, c), x) -> (fmul x, c+1)
15329         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
15330           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
15331                                        DAG.getConstantFP(1.0, DL, VT));
15332           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
15333         }
15334 
15335         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
15336         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
15337             N1.getOperand(0) == N1.getOperand(1) &&
15338             N0.getOperand(0) == N1.getOperand(0)) {
15339           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
15340                                        DAG.getConstantFP(2.0, DL, VT));
15341           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
15342         }
15343       }
15344 
15345       if (N1.getOpcode() == ISD::FMUL) {
15346         SDNode *CFP10 =
15347             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
15348         SDNode *CFP11 =
15349             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
15350 
15351         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
15352         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
15353           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
15354                                        DAG.getConstantFP(1.0, DL, VT));
15355           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
15356         }
15357 
15358         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
15359         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
15360             N0.getOperand(0) == N0.getOperand(1) &&
15361             N1.getOperand(0) == N0.getOperand(0)) {
15362           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
15363                                        DAG.getConstantFP(2.0, DL, VT));
15364           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
15365         }
15366       }
15367 
15368       if (N0.getOpcode() == ISD::FADD) {
15369         SDNode *CFP00 =
15370             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
15371         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
15372         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
15373             (N0.getOperand(0) == N1)) {
15374           return DAG.getNode(ISD::FMUL, DL, VT, N1,
15375                              DAG.getConstantFP(3.0, DL, VT));
15376         }
15377       }
15378 
15379       if (N1.getOpcode() == ISD::FADD) {
15380         SDNode *CFP10 =
15381             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
15382         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
15383         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
15384             N1.getOperand(0) == N0) {
15385           return DAG.getNode(ISD::FMUL, DL, VT, N0,
15386                              DAG.getConstantFP(3.0, DL, VT));
15387         }
15388       }
15389 
15390       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
15391       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
15392           N0.getOperand(0) == N0.getOperand(1) &&
15393           N1.getOperand(0) == N1.getOperand(1) &&
15394           N0.getOperand(0) == N1.getOperand(0)) {
15395         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
15396                            DAG.getConstantFP(4.0, DL, VT));
15397       }
15398     }
15399   } // enable-unsafe-fp-math
15400 
15401   // FADD -> FMA combines:
15402   if (SDValue Fused = visitFADDForFMACombine(N)) {
15403     AddToWorklist(Fused.getNode());
15404     return Fused;
15405   }
15406   return SDValue();
15407 }
15408 
15409 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
15410   SDValue Chain = N->getOperand(0);
15411   SDValue N0 = N->getOperand(1);
15412   SDValue N1 = N->getOperand(2);
15413   EVT VT = N->getValueType(0);
15414   EVT ChainVT = N->getValueType(1);
15415   SDLoc DL(N);
15416   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15417 
15418   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
15419   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
15420     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
15421             N1, DAG, LegalOperations, ForCodeSize)) {
15422       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
15423                          {Chain, N0, NegN1});
15424     }
15425 
15426   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
15427   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
15428     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
15429             N0, DAG, LegalOperations, ForCodeSize)) {
15430       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
15431                          {Chain, N1, NegN0});
15432     }
15433   return SDValue();
15434 }
15435 
15436 SDValue DAGCombiner::visitFSUB(SDNode *N) {
15437   SDValue N0 = N->getOperand(0);
15438   SDValue N1 = N->getOperand(1);
15439   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
15440   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
15441   EVT VT = N->getValueType(0);
15442   SDLoc DL(N);
15443   const TargetOptions &Options = DAG.getTarget().Options;
15444   const SDNodeFlags Flags = N->getFlags();
15445   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15446 
15447   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15448     return R;
15449 
15450   // fold (fsub c1, c2) -> c1-c2
15451   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
15452     return C;
15453 
15454   // fold vector ops
15455   if (VT.isVector())
15456     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15457       return FoldedVOp;
15458 
15459   if (SDValue NewSel = foldBinOpIntoSelect(N))
15460     return NewSel;
15461 
15462   // (fsub A, 0) -> A
15463   if (N1CFP && N1CFP->isZero()) {
15464     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
15465         Flags.hasNoSignedZeros()) {
15466       return N0;
15467     }
15468   }
15469 
15470   if (N0 == N1) {
15471     // (fsub x, x) -> 0.0
15472     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
15473       return DAG.getConstantFP(0.0f, DL, VT);
15474   }
15475 
15476   // (fsub -0.0, N1) -> -N1
15477   if (N0CFP && N0CFP->isZero()) {
15478     if (N0CFP->isNegative() ||
15479         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
15480       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
15481       // flushed to zero, unless all users treat denorms as zero (DAZ).
15482       // FIXME: This transform will change the sign of a NaN and the behavior
15483       // of a signaling NaN. It is only valid when a NoNaN flag is present.
15484       DenormalMode DenormMode = DAG.getDenormalMode(VT);
15485       if (DenormMode == DenormalMode::getIEEE()) {
15486         if (SDValue NegN1 =
15487                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
15488           return NegN1;
15489         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
15490           return DAG.getNode(ISD::FNEG, DL, VT, N1);
15491       }
15492     }
15493   }
15494 
15495   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
15496        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
15497       N1.getOpcode() == ISD::FADD) {
15498     // X - (X + Y) -> -Y
15499     if (N0 == N1->getOperand(0))
15500       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
15501     // X - (Y + X) -> -Y
15502     if (N0 == N1->getOperand(1))
15503       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
15504   }
15505 
15506   // fold (fsub A, (fneg B)) -> (fadd A, B)
15507   if (SDValue NegN1 =
15508           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
15509     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
15510 
15511   // FSUB -> FMA combines:
15512   if (SDValue Fused = visitFSUBForFMACombine(N)) {
15513     AddToWorklist(Fused.getNode());
15514     return Fused;
15515   }
15516 
15517   return SDValue();
15518 }
15519 
15520 SDValue DAGCombiner::visitFMUL(SDNode *N) {
15521   SDValue N0 = N->getOperand(0);
15522   SDValue N1 = N->getOperand(1);
15523   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
15524   EVT VT = N->getValueType(0);
15525   SDLoc DL(N);
15526   const TargetOptions &Options = DAG.getTarget().Options;
15527   const SDNodeFlags Flags = N->getFlags();
15528   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15529 
15530   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15531     return R;
15532 
15533   // fold (fmul c1, c2) -> c1*c2
15534   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
15535     return C;
15536 
15537   // canonicalize constant to RHS
15538   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
15539      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
15540     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
15541 
15542   // fold vector ops
15543   if (VT.isVector())
15544     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15545       return FoldedVOp;
15546 
15547   if (SDValue NewSel = foldBinOpIntoSelect(N))
15548     return NewSel;
15549 
15550   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
15551     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
15552     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15553         N0.getOpcode() == ISD::FMUL) {
15554       SDValue N00 = N0.getOperand(0);
15555       SDValue N01 = N0.getOperand(1);
15556       // Avoid an infinite loop by making sure that N00 is not a constant
15557       // (the inner multiply has not been constant folded yet).
15558       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
15559           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
15560         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
15561         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
15562       }
15563     }
15564 
15565     // Match a special case: X * 2.0 has been canonicalized to (fadd X, X).
15566     // fmul (fadd X, X), C -> fmul X, 2.0 * C
15567     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
15568         N0.getOperand(0) == N0.getOperand(1)) {
15569       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
15570       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
15571       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
15572     }
15573   }
15574 
15575   // fold (fmul X, 2.0) -> (fadd X, X)
15576   if (N1CFP && N1CFP->isExactlyValue(+2.0))
15577     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
15578 
15579   // fold (fmul X, -1.0) -> (fsub -0.0, X)
15580   if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
15581     if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
15582       return DAG.getNode(ISD::FSUB, DL, VT,
15583                          DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
15584     }
15585   }
15586 
15587   // -N0 * -N1 --> N0 * N1
15588   TargetLowering::NegatibleCost CostN0 =
15589       TargetLowering::NegatibleCost::Expensive;
15590   TargetLowering::NegatibleCost CostN1 =
15591       TargetLowering::NegatibleCost::Expensive;
15592   SDValue NegN0 =
15593       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15594   if (NegN0) {
15595     HandleSDNode NegN0Handle(NegN0);
15596     SDValue NegN1 =
15597         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15598     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15599                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
15600       return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
15601   }
15602 
15603   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
15604   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
15605   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
15606       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
15607       TLI.isOperationLegal(ISD::FABS, VT)) {
15608     SDValue Select = N0, X = N1;
15609     if (Select.getOpcode() != ISD::SELECT)
15610       std::swap(Select, X);
15611 
15612     SDValue Cond = Select.getOperand(0);
15613     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
15614     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
15615 
15616     if (TrueOpnd && FalseOpnd &&
15617         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
15618         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
15619         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
15620       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
15621       switch (CC) {
15622       default: break;
15623       case ISD::SETOLT:
15624       case ISD::SETULT:
15625       case ISD::SETOLE:
15626       case ISD::SETULE:
15627       case ISD::SETLT:
15628       case ISD::SETLE:
15629         std::swap(TrueOpnd, FalseOpnd);
15630         [[fallthrough]];
15631       case ISD::SETOGT:
15632       case ISD::SETUGT:
15633       case ISD::SETOGE:
15634       case ISD::SETUGE:
15635       case ISD::SETGT:
15636       case ISD::SETGE:
15637         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
15638             TLI.isOperationLegal(ISD::FNEG, VT))
15639           return DAG.getNode(ISD::FNEG, DL, VT,
15640                    DAG.getNode(ISD::FABS, DL, VT, X));
15641         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
15642           return DAG.getNode(ISD::FABS, DL, VT, X);
15643 
15644         break;
15645       }
15646     }
15647   }
15648 
15649   // FMUL -> FMA combines:
15650   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
15651     AddToWorklist(Fused.getNode());
15652     return Fused;
15653   }
15654 
15655   return SDValue();
15656 }
15657 
15658 SDValue DAGCombiner::visitFMA(SDNode *N) {
15659   SDValue N0 = N->getOperand(0);
15660   SDValue N1 = N->getOperand(1);
15661   SDValue N2 = N->getOperand(2);
15662   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
15663   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
15664   EVT VT = N->getValueType(0);
15665   SDLoc DL(N);
15666   const TargetOptions &Options = DAG.getTarget().Options;
15667   // FMA nodes have flags that propagate to the created nodes.
15668   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15669 
15670   bool CanReassociate =
15671       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15672 
15673   // Constant fold FMA.
15674   if (isa<ConstantFPSDNode>(N0) &&
15675       isa<ConstantFPSDNode>(N1) &&
15676       isa<ConstantFPSDNode>(N2)) {
15677     return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
15678   }
15679 
15680   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
15681   TargetLowering::NegatibleCost CostN0 =
15682       TargetLowering::NegatibleCost::Expensive;
15683   TargetLowering::NegatibleCost CostN1 =
15684       TargetLowering::NegatibleCost::Expensive;
15685   SDValue NegN0 =
15686       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15687   if (NegN0) {
15688     HandleSDNode NegN0Handle(NegN0);
15689     SDValue NegN1 =
15690         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15691     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15692                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
15693       return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
15694   }
15695 
15696   // FIXME: use fast math flags instead of Options.UnsafeFPMath
15697   if (Options.UnsafeFPMath) {
15698     if (N0CFP && N0CFP->isZero())
15699       return N2;
15700     if (N1CFP && N1CFP->isZero())
15701       return N2;
15702   }
15703 
15704   if (N0CFP && N0CFP->isExactlyValue(1.0))
15705     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
15706   if (N1CFP && N1CFP->isExactlyValue(1.0))
15707     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
15708 
15709   // Canonicalize (fma c, x, y) -> (fma x, c, y)
15710   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
15711      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
15712     return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
15713 
15714   if (CanReassociate) {
15715     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
15716     if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
15717         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15718         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
15719       return DAG.getNode(ISD::FMUL, DL, VT, N0,
15720                          DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
15721     }
15722 
15723     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
15724     if (N0.getOpcode() == ISD::FMUL &&
15725         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15726         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
15727       return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
15728                          DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
15729                          N2);
15730     }
15731   }
15732 
15733   // (fma x, 1, y) -> (fadd x, y); (fma x, -1, y) -> (fadd (fneg x), y)
15734   if (N1CFP) {
15735     if (N1CFP->isExactlyValue(1.0))
15736       return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
15737 
15738     if (N1CFP->isExactlyValue(-1.0) &&
15739         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
15740       SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
15741       AddToWorklist(RHSNeg.getNode());
15742       return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
15743     }
15744 
15745     // fma (fneg x), K, y -> fma x, -K, y
15746     if (N0.getOpcode() == ISD::FNEG &&
15747         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
15748          (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
15749                                               ForCodeSize)))) {
15750       return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
15751                          DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
15752     }
15753   }
15754 
15755   if (CanReassociate) {
15756     // (fma x, c, x) -> (fmul x, (c+1))
15757     if (N1CFP && N0 == N2) {
15758       return DAG.getNode(
15759           ISD::FMUL, DL, VT, N0,
15760           DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
15761     }
15762 
15763     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
15764     if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
15765       return DAG.getNode(
15766           ISD::FMUL, DL, VT, N0,
15767           DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
15768     }
15769   }
15770 
15771   // fold (fma (fneg X), Y, (fneg Z)) -> (fneg (fma X, Y, Z))
15772   // fold (fma X, (fneg Y), (fneg Z)) -> (fneg (fma X, Y, Z))
15773   if (!TLI.isFNegFree(VT))
15774     if (SDValue Neg = TLI.getCheaperNegatedExpression(
15775             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
15776       return DAG.getNode(ISD::FNEG, DL, VT, Neg);
15777   return SDValue();
15778 }
15779 
15780 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
15781 // reciprocal.
15782 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
15783 // Notice that this is not always beneficial. One reason is that different
15784 // targets may have different costs for FDIV and FMUL, so sometimes two FDIVs
15785 // may be cheaper than one FDIV and two FMULs. Another reason is that the
15786 // critical path grows from "one FDIV" to "one FDIV + one FMUL".
15787 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
15788   // TODO: Limit this transform based on optsize/minsize - it always creates at
15789   //       least 1 extra instruction. But the perf win may be substantial enough
15790   //       that only minsize should restrict this.
15791   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
15792   const SDNodeFlags Flags = N->getFlags();
15793   if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
15794     return SDValue();
15795 
15796   // Skip if current node is a reciprocal/fneg-reciprocal.
15797   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
15798   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
15799   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
15800     return SDValue();
15801 
15802   // Exit early if the target does not want this transform or if there can't
15803   // possibly be enough uses of the divisor to make the transform worthwhile.
15804   unsigned MinUses = TLI.combineRepeatedFPDivisors();
15805 
15806   // For splat vectors, scale the number of uses by the splat factor. If we can
15807   // convert the division into a scalar op, that will likely be much faster.
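        // For instance, a single FDIV whose divisor is a splat <4 x float>
        // counts as four uses when compared against the MinUses threshold.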
15808   unsigned NumElts = 1;
15809   EVT VT = N->getValueType(0);
15810   if (VT.isVector() && DAG.isSplatValue(N1))
15811     NumElts = VT.getVectorMinNumElements();
15812 
15813   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
15814     return SDValue();
15815 
15816   // Find all FDIV users of the same divisor.
15817   // Use a set because duplicates may be present in the user list.
15818   SetVector<SDNode *> Users;
15819   for (auto *U : N1->uses()) {
15820     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
15821       // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
15822       if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
15823           U->getOperand(0) == U->getOperand(1).getOperand(0) &&
15824           U->getFlags().hasAllowReassociation() &&
15825           U->getFlags().hasNoSignedZeros())
15826         continue;
15827 
15828       // This division is eligible for optimization only if global unsafe math
15829       // is enabled or if this division allows reciprocal formation.
15830       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
15831         Users.insert(U);
15832     }
15833   }
15834 
15835   // Now that we have the actual number of divisor uses, make sure it meets
15836   // the minimum threshold specified by the target.
15837   if ((Users.size() * NumElts) < MinUses)
15838     return SDValue();
15839 
15840   SDLoc DL(N);
15841   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
15842   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
15843 
15844   // Dividend / Divisor -> Dividend * Reciprocal
15845   for (auto *U : Users) {
15846     SDValue Dividend = U->getOperand(0);
15847     if (Dividend != FPOne) {
15848       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
15849                                     Reciprocal, Flags);
15850       CombineTo(U, NewNode);
15851     } else if (U != Reciprocal.getNode()) {
15852       // In the absence of fast-math-flags, this user node is always the
15853       // same node as Reciprocal, but with FMF they may be different nodes.
15854       CombineTo(U, Reciprocal);
15855     }
15856   }
15857   return SDValue(N, 0);  // N was replaced.
15858 }
15859 
15860 SDValue DAGCombiner::visitFDIV(SDNode *N) {
15861   SDValue N0 = N->getOperand(0);
15862   SDValue N1 = N->getOperand(1);
15863   EVT VT = N->getValueType(0);
15864   SDLoc DL(N);
15865   const TargetOptions &Options = DAG.getTarget().Options;
15866   SDNodeFlags Flags = N->getFlags();
15867   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15868 
15869   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15870     return R;
15871 
15872   // fold (fdiv c1, c2) -> c1/c2
15873   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
15874     return C;
15875 
15876   // fold vector ops
15877   if (VT.isVector())
15878     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15879       return FoldedVOp;
15880 
15881   if (SDValue NewSel = foldBinOpIntoSelect(N))
15882     return NewSel;
15883 
15884   if (SDValue V = combineRepeatedFPDivisors(N))
15885     return V;
15886 
15887   if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
15888     // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
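          // For instance, X / 2.0 becomes the exact X * 0.5 (opOK), while
          // X / 3.0 becomes X * (1.0 / 3.0), which is inexact (opInexact)
          // but still accepted by the check below.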
15889     if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1)) {
15890       // Compute the reciprocal 1.0 / c2.
15891       const APFloat &N1APF = N1CFP->getValueAPF();
15892       APFloat Recip(N1APF.getSemantics(), 1); // 1.0
15893       APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
15894       // Only do the transform if the reciprocal is a legal fp immediate that
15895       // isn't too nasty (eg NaN, denormal, ...).
15896       if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
15897           (!LegalOperations ||
15898            // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
15899            // backend)... we should handle this gracefully after Legalize.
15900            // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
15901            TLI.isOperationLegal(ISD::ConstantFP, VT) ||
15902            TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
15903         return DAG.getNode(ISD::FMUL, DL, VT, N0,
15904                            DAG.getConstantFP(Recip, DL, VT));
15905     }
15906 
15907     // If this FDIV is part of a reciprocal square root, it may be folded
15908     // into a target-specific square root estimate instruction.
15909     if (N1.getOpcode() == ISD::FSQRT) {
15910       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
15911         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15912     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
15913                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15914       if (SDValue RV =
15915               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
15916         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
15917         AddToWorklist(RV.getNode());
15918         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15919       }
15920     } else if (N1.getOpcode() == ISD::FP_ROUND &&
15921                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15922       if (SDValue RV =
15923               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
15924         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
15925         AddToWorklist(RV.getNode());
15926         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15927       }
15928     } else if (N1.getOpcode() == ISD::FMUL) {
15929       // Look through an FMUL. Even though this won't remove the FDIV directly,
15930       // it's still worthwhile to get rid of the FSQRT if possible.
15931       SDValue Sqrt, Y;
15932       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15933         Sqrt = N1.getOperand(0);
15934         Y = N1.getOperand(1);
15935       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
15936         Sqrt = N1.getOperand(1);
15937         Y = N1.getOperand(0);
15938       }
15939       if (Sqrt.getNode()) {
15940         // If the other multiply operand is known positive, pull it into the
15941         // sqrt. That will eliminate the division if we convert to an estimate.
15942         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
15943             N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
15944           SDValue A;
15945           if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
15946             A = Y.getOperand(0);
15947           else if (Y == Sqrt.getOperand(0))
15948             A = Y;
15949           if (A) {
15950             // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
15951             // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
15952             SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
15953             SDValue AAZ =
15954                 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
15955             if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
15956               return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
15957 
15958             // Estimate creation failed. Clean up speculatively created nodes.
15959             recursivelyDeleteUnusedNodes(AAZ.getNode());
15960           }
15961         }
15962 
15963         // We found a FSQRT, so try to make this fold:
15964         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
15965         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
15966           SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
15967           AddToWorklist(Div.getNode());
15968           return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
15969         }
15970       }
15971     }
15972 
15973     // Fold into a reciprocal estimate and multiply instead of a real divide.
15974     if (Options.NoInfsFPMath || Flags.hasNoInfs())
15975       if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
15976         return RV;
15977   }
15978 
15979   // Fold X/Sqrt(X) -> Sqrt(X)
15980   if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
15981       (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
15982     if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
15983       return N1;
15984 
15985   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
15986   TargetLowering::NegatibleCost CostN0 =
15987       TargetLowering::NegatibleCost::Expensive;
15988   TargetLowering::NegatibleCost CostN1 =
15989       TargetLowering::NegatibleCost::Expensive;
15990   SDValue NegN0 =
15991       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15992   if (NegN0) {
15993     HandleSDNode NegN0Handle(NegN0);
15994     SDValue NegN1 =
15995         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15996     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15997                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
15998       return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
15999   }
16000 
16001   return SDValue();
16002 }
16003 
16004 SDValue DAGCombiner::visitFREM(SDNode *N) {
16005   SDValue N0 = N->getOperand(0);
16006   SDValue N1 = N->getOperand(1);
16007   EVT VT = N->getValueType(0);
16008   SDNodeFlags Flags = N->getFlags();
16009   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16010 
16011   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16012     return R;
16013 
16014   // fold (frem c1, c2) -> fmod(c1,c2)
16015   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1}))
16016     return C;
16017 
16018   if (SDValue NewSel = foldBinOpIntoSelect(N))
16019     return NewSel;
16020 
16021   return SDValue();
16022 }
16023 
16024 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
16025   SDNodeFlags Flags = N->getFlags();
16026   const TargetOptions &Options = DAG.getTarget().Options;
16027 
16028   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
16029   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
16030   if (!Flags.hasApproximateFuncs() ||
16031       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
16032     return SDValue();
16033 
16034   SDValue N0 = N->getOperand(0);
16035   if (TLI.isFsqrtCheap(N0, DAG))
16036     return SDValue();
16037 
16038   // FSQRT nodes have flags that propagate to the created nodes.
16039   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
16040   //       transform the fdiv, we may produce a sub-optimal estimate sequence
16041   //       because the reciprocal calculation may not have to filter out a
16042   //       0.0 input.
16043   return buildSqrtEstimate(N0, Flags);
16044 }
16045 
16046 /// copysign(x, fp_extend(y)) -> copysign(x, y)
16047 /// copysign(x, fp_round(y)) -> copysign(x, y)
16048 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
16049   SDValue N1 = N->getOperand(1);
16050   if ((N1.getOpcode() == ISD::FP_EXTEND ||
16051        N1.getOpcode() == ISD::FP_ROUND)) {
16052     EVT N1VT = N1->getValueType(0);
16053     EVT N1Op0VT = N1->getOperand(0).getValueType();
16054 
16055     // Always fold no-op FP casts.
16056     if (N1VT == N1Op0VT)
16057       return true;
16058 
16059     // Do not optimize out type conversion of f128 type yet.
16060     // For some targets like x86_64, configuration is changed to keep one f128
16061     // value in one SSE register, but instruction selection cannot handle
16062     // FCOPYSIGN on SSE registers yet.
16063     if (N1Op0VT == MVT::f128)
16064       return false;
16065 
16066     return !N1Op0VT.isVector() || EnableVectorFCopySignExtendRound;
16067   }
16068   return false;
16069 }
16070 
16071 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
16072   SDValue N0 = N->getOperand(0);
16073   SDValue N1 = N->getOperand(1);
16074   EVT VT = N->getValueType(0);
16075 
16076   // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
16077   if (SDValue C =
16078           DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, SDLoc(N), VT, {N0, N1}))
16079     return C;
16080 
16081   if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
16082     const APFloat &V = N1C->getValueAPF();
16083     // copysign(x, c1) -> fabs(x)       iff ispos(c1)
16084     // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
16085     if (!V.isNegative()) {
16086       if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
16087         return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16088     } else {
16089       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
16090         return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
16091                            DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
16092     }
16093   }
16094 
16095   // copysign(fabs(x), y) -> copysign(x, y)
16096   // copysign(fneg(x), y) -> copysign(x, y)
16097   // copysign(copysign(x,z), y) -> copysign(x, y)
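        // (The first operand contributes only its magnitude, which FABS, FNEG
        // and FCOPYSIGN leave unchanged.)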
16098   if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
16099       N0.getOpcode() == ISD::FCOPYSIGN)
16100     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
16101 
16102   // copysign(x, abs(y)) -> abs(x)
16103   if (N1.getOpcode() == ISD::FABS)
16104     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16105 
16106   // copysign(x, copysign(y,z)) -> copysign(x, z)
16107   if (N1.getOpcode() == ISD::FCOPYSIGN)
16108     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
16109 
16110   // copysign(x, fp_extend(y)) -> copysign(x, y)
16111   // copysign(x, fp_round(y)) -> copysign(x, y)
16112   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
16113     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
16114 
16115   return SDValue();
16116 }
16117 
16118 SDValue DAGCombiner::visitFPOW(SDNode *N) {
16119   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
16120   if (!ExponentC)
16121     return SDValue();
16122   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16123 
16124   // Try to convert x ** (1/3) into cube root.
16125   // TODO: Handle the various flavors of long double.
16126   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
16127   //       Some range near 1/3 should be fine.
16128   EVT VT = N->getValueType(0);
16129   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
16130       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
16131     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
16132     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
16133     // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
16134     // For regular numbers, rounding may cause the results to differ.
16135     // Therefore, we require { nsz ninf nnan afn } for this transform.
16136     // TODO: We could select out the special cases if we don't have nsz/ninf.
16137     SDNodeFlags Flags = N->getFlags();
16138     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
16139         !Flags.hasApproximateFuncs())
16140       return SDValue();
16141 
16142     // Do not create a cbrt() libcall if the target does not have it, and do not
16143     // turn a pow that has lowering support into a cbrt() libcall.
16144     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
16145         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
16146          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
16147       return SDValue();
16148 
16149     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
16150   }
16151 
16152   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
16153   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
16154   // TODO: This could be extended (using a target hook) to handle smaller
16155   // power-of-2 fractional exponents.
16156   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
16157   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
16158   if (ExponentIs025 || ExponentIs075) {
16159     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
16160     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
16161     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
16162     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
16163     // For regular numbers, rounding may cause the results to differ.
16164     // Therefore, we require { nsz ninf afn } for this transform.
16165     // TODO: We could select out the special cases if we don't have nsz/ninf.
16166     SDNodeFlags Flags = N->getFlags();
16167 
16168     // We only need no signed zeros for the 0.25 case.
16169     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
16170         !Flags.hasApproximateFuncs())
16171       return SDValue();
16172 
16173     // Don't double the number of libcalls. We are trying to inline fast code.
16174     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
16175       return SDValue();
16176 
16177     // Assume that libcalls are the smallest code.
16178     // TODO: This restriction should probably be lifted for vectors.
16179     if (ForCodeSize)
16180       return SDValue();
16181 
16182     // pow(X, 0.25) --> sqrt(sqrt(X))
16183     SDLoc DL(N);
16184     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
16185     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
16186     if (ExponentIs025)
16187       return SqrtSqrt;
16188     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
16189     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
16190   }
16191 
16192   return SDValue();
16193 }
16194 
16195 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
16196                                const TargetLowering &TLI) {
16197   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
16198   // replacing casts with a libcall. We also must be allowed to ignore -0.0
16199   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
16200   // conversions would return +0.0.
16201   // FIXME: We should be able to use node-level FMF here.
16202   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
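        // For instance, fp_to_sint(-1.5) == -1 and sint_to_fp(-1) == -1.0,
        // matching ftrunc(-1.5) == -1.0: both conversion steps round toward
        // zero.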
16203   EVT VT = N->getValueType(0);
16204   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
16205       !DAG.getTarget().Options.NoSignedZerosFPMath)
16206     return SDValue();
16207 
16208   // fptosi/fptoui round towards zero, so converting from FP to integer and
16209   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
16210   SDValue N0 = N->getOperand(0);
16211   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
16212       N0.getOperand(0).getValueType() == VT)
16213     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
16214 
16215   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
16216       N0.getOperand(0).getValueType() == VT)
16217     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
16218 
16219   return SDValue();
16220 }
16221 
16222 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
16223   SDValue N0 = N->getOperand(0);
16224   EVT VT = N->getValueType(0);
16225   EVT OpVT = N0.getValueType();
16226 
16227   // [us]itofp(undef) = 0, because the result value is bounded.
16228   if (N0.isUndef())
16229     return DAG.getConstantFP(0.0, SDLoc(N), VT);
16230 
16231   // fold (sint_to_fp c1) -> c1fp
16232   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
16233       // ...but only if the target supports immediate floating-point values
16234       (!LegalOperations ||
16235        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
16236     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
16237 
16238   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
16239   // but UINT_TO_FP is legal on this target, try to convert.
16240   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
16241       hasOperation(ISD::UINT_TO_FP, OpVT)) {
16242     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
16243     if (DAG.SignBitIsZero(N0))
16244       return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
16245   }
16246 
16247   // The next optimizations are desirable only if SELECT_CC can be lowered.
16248   // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
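        // (An i1 'true' is all ones, i.e. -1 when interpreted as signed, which
        // is why the true arm is -1.0 rather than 1.0.)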
16249   if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
16250       !VT.isVector() &&
16251       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16252     SDLoc DL(N);
16253     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
16254                          DAG.getConstantFP(0.0, DL, VT));
16255   }
16256 
16257   // fold (sint_to_fp (zext (setcc x, y, cc))) ->
16258   //      (select (setcc x, y, cc), 1.0, 0.0)
16259   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
16260       N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
16261       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16262     SDLoc DL(N);
16263     return DAG.getSelect(DL, VT, N0.getOperand(0),
16264                          DAG.getConstantFP(1.0, DL, VT),
16265                          DAG.getConstantFP(0.0, DL, VT));
16266   }
16267 
16268   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
16269     return FTrunc;
16270 
16271   return SDValue();
16272 }
16273 
16274 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
16275   SDValue N0 = N->getOperand(0);
16276   EVT VT = N->getValueType(0);
16277   EVT OpVT = N0.getValueType();
16278 
16279   // [us]itofp(undef) = 0, because the result value is bounded.
16280   if (N0.isUndef())
16281     return DAG.getConstantFP(0.0, SDLoc(N), VT);
16282 
16283   // fold (uint_to_fp c1) -> c1fp
16284   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
16285       // ...but only if the target supports immediate floating-point values
16286       (!LegalOperations ||
16287        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
16288     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
16289 
16290   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
16291   // but SINT_TO_FP is legal on this target, try to convert.
16292   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
16293       hasOperation(ISD::SINT_TO_FP, OpVT)) {
16294     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
16295     if (DAG.SignBitIsZero(N0))
16296       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
16297   }
16298 
16299   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
16300   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
16301       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16302     SDLoc DL(N);
16303     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
16304                          DAG.getConstantFP(0.0, DL, VT));
16305   }
16306 
16307   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
16308     return FTrunc;
16309 
16310   return SDValue();
16311 }
16312 
16313 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
16314 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
16315   SDValue N0 = N->getOperand(0);
16316   EVT VT = N->getValueType(0);
16317 
16318   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
16319     return SDValue();
16320 
16321   SDValue Src = N0.getOperand(0);
16322   EVT SrcVT = Src.getValueType();
16323   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
16324   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
16325 
16326   // We can safely assume the conversion won't overflow the output range,
16327   // because (for example) (uint8_t)18293.f is undefined behavior.
16328 
16329   // Since we can assume the conversion won't overflow, our decision as to
16330   // whether the input will fit in the float should depend on the minimum
16331   // of the input range and output range.
16332 
16333   // This means this is also safe for a signed input and unsigned output, since
16334   // a negative input would lead to undefined behavior.
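        // For instance, with f32 (24-bit significand): a signed i16 round trip
        // has 15 value bits and folds away, while a signed i32 round trip has
        // 31 value bits and is kept, because not every i32 value is exactly
        // representable in f32.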
16335   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
16336   unsigned OutputSize = (int)VT.getScalarSizeInBits();
16337   unsigned ActualSize = std::min(InputSize, OutputSize);
16338   const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
16339 
16340   // We can only fold away the float conversion if the input range can be
16341   // represented exactly in the float range.
16342   if (APFloat::semanticsPrecision(sem) >= ActualSize) {
16343     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
16344       unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
16345                                                        : ISD::ZERO_EXTEND;
16346       return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
16347     }
16348     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
16349       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
16350     return DAG.getBitcast(VT, Src);
16351   }
16352   return SDValue();
16353 }
16354 
16355 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
16356   SDValue N0 = N->getOperand(0);
16357   EVT VT = N->getValueType(0);
16358 
16359   // fold (fp_to_sint undef) -> undef
16360   if (N0.isUndef())
16361     return DAG.getUNDEF(VT);
16362 
16363   // fold (fp_to_sint c1fp) -> c1
16364   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16365     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
16366 
16367   return FoldIntToFPToInt(N, DAG);
16368 }
16369 
16370 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
16371   SDValue N0 = N->getOperand(0);
16372   EVT VT = N->getValueType(0);
16373 
16374   // fold (fp_to_uint undef) -> undef
16375   if (N0.isUndef())
16376     return DAG.getUNDEF(VT);
16377 
16378   // fold (fp_to_uint c1fp) -> c1
16379   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16380     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
16381 
16382   return FoldIntToFPToInt(N, DAG);
16383 }
16384 
16385 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
16386   SDValue N0 = N->getOperand(0);
16387   SDValue N1 = N->getOperand(1);
16388   EVT VT = N->getValueType(0);
16389 
16390   // fold (fp_round c1fp) -> c1fp
16391   if (SDValue C =
16392           DAG.FoldConstantArithmetic(ISD::FP_ROUND, SDLoc(N), VT, {N0, N1}))
16393     return C;
16394 
16395   // fold (fp_round (fp_extend x)) -> x
16396   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
16397     return N0.getOperand(0);
16398 
16399   // fold (fp_round (fp_round x)) -> (fp_round x)
16400   if (N0.getOpcode() == ISD::FP_ROUND) {
16401     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
16402     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
16403 
16404     // Skip this folding if it results in an fp_round from f80 to f16.
16405     //
16406     // f80 to f16 always generates an expensive (and as yet, unimplemented)
16407     // libcall to __truncxfhf2 instead of selecting native f16 conversion
16408     // instructions from f32 or f64.  Moreover, the first (value-preserving)
16409     // fp_round from f80 to either f32 or f64 may become a NOP on platforms like
16410     // x86.
16411     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
16412       return SDValue();
16413 
16414     // If the first fp_round isn't a value-preserving truncation, it might
16415     // introduce a tie in the second fp_round that wouldn't occur in the
16416     // single-step fp_round we want to fold to.
16417     // In other words, double rounding isn't the same as rounding.
16418     // Also, this is a value preserving truncation iff both fp_round's are.
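    //
    // Illustrative case (an assumption-free sketch of the hazard): an f64
    // value that is not exactly representable in f16 can round under
    // f64->f32 to a result exactly halfway between two adjacent f16 values;
    // the second rounding then resolves a tie that a direct, single-step
    // f64->f16 rounding would never have seen.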
16419     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
16420       SDLoc DL(N);
16421       return DAG.getNode(
16422           ISD::FP_ROUND, DL, VT, N0.getOperand(0),
16423           DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
16424     }
16425   }
16426 
16427   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
16428   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) {
16429     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
16430                               N0.getOperand(0), N1);
16431     AddToWorklist(Tmp.getNode());
16432     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
16433                        Tmp, N0.getOperand(1));
16434   }
16435 
16436   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16437     return NewVSel;
16438 
16439   return SDValue();
16440 }
16441 
16442 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
16443   SDValue N0 = N->getOperand(0);
16444   EVT VT = N->getValueType(0);
16445 
16446   if (VT.isVector())
16447     if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
16448       return FoldedVOp;
16449 
16450   // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
16451   if (N->hasOneUse() &&
16452       N->use_begin()->getOpcode() == ISD::FP_ROUND)
16453     return SDValue();
16454 
16455   // fold (fp_extend c1fp) -> c1fp
16456   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16457     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
16458 
16459   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
16460   if (N0.getOpcode() == ISD::FP16_TO_FP &&
16461       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
16462     return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
16463 
16464   // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
16465   // value of X.
16466   if (N0.getOpcode() == ISD::FP_ROUND
16467       && N0.getConstantOperandVal(1) == 1) {
16468     SDValue In = N0.getOperand(0);
16469     if (In.getValueType() == VT) return In;
16470     if (VT.bitsLT(In.getValueType()))
16471       return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
16472                          In, N0.getOperand(1));
16473     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
16474   }
16475 
16476   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
16477   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16478       TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
16479     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16480     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
16481                                      LN0->getChain(),
16482                                      LN0->getBasePtr(), N0.getValueType(),
16483                                      LN0->getMemOperand());
16484     CombineTo(N, ExtLoad);
16485     CombineTo(
16486         N0.getNode(),
16487         DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
16488                     DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
16489         ExtLoad.getValue(1));
16490     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
16491   }
16492 
16493   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16494     return NewVSel;
16495 
16496   return SDValue();
16497 }
16498 
16499 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
16500   SDValue N0 = N->getOperand(0);
16501   EVT VT = N->getValueType(0);
16502 
16503   // fold (fceil c1) -> fceil(c1)
16504   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16505     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
16506 
16507   return SDValue();
16508 }
16509 
16510 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
16511   SDValue N0 = N->getOperand(0);
16512   EVT VT = N->getValueType(0);
16513 
16514   // fold (ftrunc c1) -> ftrunc(c1)
16515   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16516     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
16517 
16518   // fold ftrunc (known rounded int x) -> x
16519   // ftrunc is a part of the fptosi/fptoui expansion on some targets, so it is
16520   // likely to be generated to extract an integer from a rounded floating value.
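  // e.g. fold ftrunc (ffloor x) -> ffloor x, since ffloor already yields an
  // integral floating-point value that ftrunc leaves unchanged.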
16521   switch (N0.getOpcode()) {
16522   default: break;
16523   case ISD::FRINT:
16524   case ISD::FTRUNC:
16525   case ISD::FNEARBYINT:
16526   case ISD::FFLOOR:
16527   case ISD::FCEIL:
16528     return N0;
16529   }
16530 
16531   return SDValue();
16532 }
16533 
16534 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
16535   SDValue N0 = N->getOperand(0);
16536   EVT VT = N->getValueType(0);
16537 
16538   // fold (ffloor c1) -> ffloor(c1)
16539   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16540     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
16541 
16542   return SDValue();
16543 }
16544 
16545 SDValue DAGCombiner::visitFNEG(SDNode *N) {
16546   SDValue N0 = N->getOperand(0);
16547   EVT VT = N->getValueType(0);
16548   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16549 
16550   // Constant fold FNEG.
16551   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16552     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
16553 
16554   if (SDValue NegN0 =
16555           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
16556     return NegN0;
16557 
16558   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
16559   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
16560   // know it was called from a context with a nsz flag if the input fsub does
16561   // not.
16562   if (N0.getOpcode() == ISD::FSUB &&
16563       (DAG.getTarget().Options.NoSignedZerosFPMath ||
16564        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
16565     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
16566                        N0.getOperand(0));
16567   }
16568 
16569   if (SDValue Cast = foldSignChangeInBitcast(N))
16570     return Cast;
16571 
16572   return SDValue();
16573 }
16574 
16575 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
16576   SDValue N0 = N->getOperand(0);
16577   SDValue N1 = N->getOperand(1);
16578   EVT VT = N->getValueType(0);
16579   const SDNodeFlags Flags = N->getFlags();
16580   unsigned Opc = N->getOpcode();
16581   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
16582   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
16583   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16584 
16585   // Constant fold.
16586   if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
16587     return C;
16588 
16589   // Canonicalize to constant on RHS.
16590   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
16591       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
16592     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
16593 
16594   if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
16595     const APFloat &AF = N1CFP->getValueAPF();
16596 
16597     // minnum(X, nan) -> X
16598     // maxnum(X, nan) -> X
16599     // minimum(X, nan) -> nan
16600     // maximum(X, nan) -> nan
16601     if (AF.isNaN())
16602       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
16603 
16604     // In the following folds, inf can be replaced with the largest finite
16605     // float, if the ninf flag is set.
16606     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
16607       // minnum(X, -inf) -> -inf
16608       // maxnum(X, +inf) -> +inf
16609       // minimum(X, -inf) -> -inf if nnan
16610       // maximum(X, +inf) -> +inf if nnan
16611       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
16612         return N->getOperand(1);
16613 
16614       // minnum(X, +inf) -> X if nnan
16615       // maxnum(X, -inf) -> X if nnan
16616       // minimum(X, +inf) -> X
16617       // maximum(X, -inf) -> X
16618       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
16619         return N->getOperand(0);
16620     }
16621   }
16622 
16623   return SDValue();
16624 }
16625 
16626 SDValue DAGCombiner::visitFABS(SDNode *N) {
16627   SDValue N0 = N->getOperand(0);
16628   EVT VT = N->getValueType(0);
16629 
16630   // fold (fabs c1) -> fabs(c1)
16631   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16632     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16633 
16634   // fold (fabs (fabs x)) -> (fabs x)
16635   if (N0.getOpcode() == ISD::FABS)
16636     return N->getOperand(0);
16637 
16638   // fold (fabs (fneg x)) -> (fabs x)
16639   // fold (fabs (fcopysign x, y)) -> (fabs x)
16640   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
16641     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
16642 
16643   if (SDValue Cast = foldSignChangeInBitcast(N))
16644     return Cast;
16645 
16646   return SDValue();
16647 }
16648 
16649 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
16650   SDValue Chain = N->getOperand(0);
16651   SDValue N1 = N->getOperand(1);
16652   SDValue N2 = N->getOperand(2);
16653 
16654   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
16655   // nondeterministic jumps).
16656   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
16657     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
16658                        N1->getOperand(0), N2);
16659   }
16660 
16661   // If N is a constant we could fold this into a fallthrough or unconditional
16662   // branch. However that doesn't happen very often in normal code, because
16663   // Instcombine/SimplifyCFG should have handled the available opportunities.
16664   // If we did this folding here, it would be necessary to update the
16665   // MachineBasicBlock CFG, which is awkward.
16666 
16667   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
16668   // on the target.
16669   if (N1.getOpcode() == ISD::SETCC &&
16670       TLI.isOperationLegalOrCustom(ISD::BR_CC,
16671                                    N1.getOperand(0).getValueType())) {
16672     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16673                        Chain, N1.getOperand(2),
16674                        N1.getOperand(0), N1.getOperand(1), N2);
16675   }
16676 
16677   if (N1.hasOneUse()) {
16678     // rebuildSetCC calls visitXor which may change the Chain when there is a
16679     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
16680     HandleSDNode ChainHandle(Chain);
16681     if (SDValue NewN1 = rebuildSetCC(N1))
16682       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
16683                          ChainHandle.getValue(), NewN1, N2);
16684   }
16685 
16686   return SDValue();
16687 }
16688 
16689 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
16690   if (N.getOpcode() == ISD::SRL ||
16691       (N.getOpcode() == ISD::TRUNCATE &&
16692        (N.getOperand(0).hasOneUse() &&
16693         N.getOperand(0).getOpcode() == ISD::SRL))) {
16694     // Look past the truncate.
16695     if (N.getOpcode() == ISD::TRUNCATE)
16696       N = N.getOperand(0);
16697 
16698     // Match this pattern so that we can generate simpler code:
16699     //
16700     //   %a = ...
16701     //   %b = and i32 %a, 2
16702     //   %c = srl i32 %b, 1
16703     //   brcond i32 %c ...
16704     //
16705     // into
16706     //
16707     //   %a = ...
16708     //   %b = and i32 %a, 2
16709     //   %c = setcc eq %b, 0
16710     //   brcond %c ...
16711     //
16712     // This applies only when the AND constant value has one bit set and the
16713     // SRL constant is equal to the log2 of the AND constant. The back-end is
16714     // smart enough to convert the result into a TEST/JMP sequence.
16715     SDValue Op0 = N.getOperand(0);
16716     SDValue Op1 = N.getOperand(1);
16717 
16718     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
16719       SDValue AndOp1 = Op0.getOperand(1);
16720 
16721       if (AndOp1.getOpcode() == ISD::Constant) {
16722         const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();
16723 
16724         if (AndConst.isPowerOf2() &&
16725             cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
16726           SDLoc DL(N);
16727           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
16728                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
16729                               ISD::SETNE);
16730         }
16731       }
16732     }
16733   }
16734 
16735   // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
16736   // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
16737   if (N.getOpcode() == ISD::XOR) {
16738     // Because we may call this on a speculatively constructed
16739     // SimplifiedSetCC Node, we need to simplify this node first.
16740     // Ideally this should be folded into SimplifySetCC and not
16741     // here. For now, grab a handle to N so we don't lose it from
16742     // replacements internal to the visit.
16743     HandleSDNode XORHandle(N);
16744     while (N.getOpcode() == ISD::XOR) {
16745       SDValue Tmp = visitXOR(N.getNode());
16746       // No simplification done.
16747       if (!Tmp.getNode())
16748         break;
16749       // Returning N is a form of in-visit replacement that may invalidate
16750       // N. Grab the value from the handle.
16751       if (Tmp.getNode() == N.getNode())
16752         N = XORHandle.getValue();
16753       else // Node simplified. Try simplifying again.
16754         N = Tmp;
16755     }
16756 
16757     if (N.getOpcode() != ISD::XOR)
16758       return N;
16759 
16760     SDValue Op0 = N->getOperand(0);
16761     SDValue Op1 = N->getOperand(1);
16762 
16763     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
16764       bool Equal = false;
16765       // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
16766       if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
16767           Op0.getValueType() == MVT::i1) {
16768         N = Op0;
16769         Op0 = N->getOperand(0);
16770         Op1 = N->getOperand(1);
16771         Equal = true;
16772       }
16773 
16774       EVT SetCCVT = N.getValueType();
16775       if (LegalTypes)
16776         SetCCVT = getSetCCResultType(SetCCVT);
16777       // Replace the uses of XOR with SETCC
16778       return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
16779                           Equal ? ISD::SETEQ : ISD::SETNE);
16780     }
16781   }
16782 
16783   return SDValue();
16784 }
16785 
16786 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
16787 //
16788 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
16789   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
16790   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
16791 
16792   // If N is a constant we could fold this into a fallthrough or unconditional
16793   // branch. However that doesn't happen very often in normal code, because
16794   // Instcombine/SimplifyCFG should have handled the available opportunities.
16795   // If we did this folding here, it would be necessary to update the
16796   // MachineBasicBlock CFG, which is awkward.
16797 
16798   // Use SimplifySetCC to simplify SETCC's.
16799   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
16800                                CondLHS, CondRHS, CC->get(), SDLoc(N),
16801                                false);
16802   if (Simp.getNode()) AddToWorklist(Simp.getNode());
16803 
16804   // fold to a simpler setcc
16805   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
16806     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16807                        N->getOperand(0), Simp.getOperand(2),
16808                        Simp.getOperand(0), Simp.getOperand(1),
16809                        N->getOperand(4));
16810 
16811   return SDValue();
16812 }
16813 
16814 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
16815                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
16816                                      const TargetLowering &TLI) {
16817   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16818     if (LD->isIndexed())
16819       return false;
16820     EVT VT = LD->getMemoryVT();
16821     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
16822       return false;
16823     Ptr = LD->getBasePtr();
16824   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16825     if (ST->isIndexed())
16826       return false;
16827     EVT VT = ST->getMemoryVT();
16828     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
16829       return false;
16830     Ptr = ST->getBasePtr();
16831     IsLoad = false;
16832   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
16833     if (LD->isIndexed())
16834       return false;
16835     EVT VT = LD->getMemoryVT();
16836     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
16837         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
16838       return false;
16839     Ptr = LD->getBasePtr();
16840     IsMasked = true;
16841   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
16842     if (ST->isIndexed())
16843       return false;
16844     EVT VT = ST->getMemoryVT();
16845     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
16846         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
16847       return false;
16848     Ptr = ST->getBasePtr();
16849     IsLoad = false;
16850     IsMasked = true;
16851   } else {
16852     return false;
16853   }
16854   return true;
16855 }
16856 
16857 /// Try turning a load/store into a pre-indexed load/store when the base
16858 /// pointer is an add or subtract and it has other uses besides the load/store.
16859 /// After the transformation, the new indexed load/store has effectively folded
16860 /// the add/subtract in and all of its other uses are redirected to the
16861 /// new load/store.
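///
/// Illustrative shape of the rewrite (exact forms are target-specific):
///   ptr2 = add ptr, 8
///   val  = load ptr2
///   ... other uses of ptr2 ...
/// becomes a pre-indexed load that yields both val and the updated pointer
/// (e.g. an AArch64 writeback form such as "ldr x0, [x1, #8]!"), with the
/// other uses of ptr2 redirected to the writeback result.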
16862 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
16863   if (Level < AfterLegalizeDAG)
16864     return false;
16865 
16866   bool IsLoad = true;
16867   bool IsMasked = false;
16868   SDValue Ptr;
16869   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
16870                                 Ptr, TLI))
16871     return false;
16872 
16873   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
16874   // out.  There is no reason to make this a preinc/predec.
16875   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
16876       Ptr->hasOneUse())
16877     return false;
16878 
16879   // Ask the target to do addressing mode selection.
16880   SDValue BasePtr;
16881   SDValue Offset;
16882   ISD::MemIndexedMode AM = ISD::UNINDEXED;
16883   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
16884     return false;
16885 
16886   // Backends without true r+i pre-indexed forms may need to pass a
16887   // constant base with a variable offset so that constant coercion
16888   // will work with the patterns in canonical form.
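  // For instance, if addressing-mode selection returned, say, BasePtr = 4096
  // and Offset = x (a hypothetical shape), the swap below presents a register
  // base with a constant-coercible offset; the swap is undone later, before
  // the indexed node is built.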
16889   bool Swapped = false;
16890   if (isa<ConstantSDNode>(BasePtr)) {
16891     std::swap(BasePtr, Offset);
16892     Swapped = true;
16893   }
16894 
16895   // Don't create an indexed load / store with zero offset.
16896   if (isNullConstant(Offset))
16897     return false;
16898 
16899   // Try turning it into a pre-indexed load / store except when:
16900   // 1) The new base ptr is a frame index.
16901   // 2) If N is a store and the new base ptr is either the same as or is a
16902   //    predecessor of the value being stored.
16903   // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
16904   //    that would create a cycle.
16905   // 4) All uses are load / store ops that use it as old base ptr.
16906 
16907   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
16908   // (plus the implicit offset) to a register to preinc anyway.
16909   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
16910     return false;
16911 
16912   // Check #2.
16913   if (!IsLoad) {
16914     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
16915                            : cast<StoreSDNode>(N)->getValue();
16916 
16917     // Would require a copy.
16918     if (Val == BasePtr)
16919       return false;
16920 
16921     // Would create a cycle.
16922     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
16923       return false;
16924   }
16925 
16926   // Caches for hasPredecessorHelper.
16927   SmallPtrSet<const SDNode *, 32> Visited;
16928   SmallVector<const SDNode *, 16> Worklist;
16929   Worklist.push_back(N);
16930 
16931   // If the offset is a constant, there may be other adds of constants that
16932   // can be folded with this one. We should do this to avoid having to keep
16933   // a copy of the original base pointer.
16934   SmallVector<SDNode *, 16> OtherUses;
16935   if (isa<ConstantSDNode>(Offset))
16936     for (SDNode::use_iterator UI = BasePtr->use_begin(),
16937                               UE = BasePtr->use_end();
16938          UI != UE; ++UI) {
16939       SDUse &Use = UI.getUse();
16940       // Skip the use that is Ptr and uses of other results from BasePtr's
16941       // node (important for nodes that return multiple results).
16942       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
16943         continue;
16944 
16945       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
16946         continue;
16947 
16948       if (Use.getUser()->getOpcode() != ISD::ADD &&
16949           Use.getUser()->getOpcode() != ISD::SUB) {
16950         OtherUses.clear();
16951         break;
16952       }
16953 
16954       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
16955       if (!isa<ConstantSDNode>(Op1)) {
16956         OtherUses.clear();
16957         break;
16958       }
16959 
16960       // FIXME: In some cases, we can be smarter about this.
16961       if (Op1.getValueType() != Offset.getValueType()) {
16962         OtherUses.clear();
16963         break;
16964       }
16965 
16966       OtherUses.push_back(Use.getUser());
16967     }
16968 
16969   if (Swapped)
16970     std::swap(BasePtr, Offset);
16971 
16972   // Now check for #3 and #4.
16973   bool RealUse = false;
16974 
16975   for (SDNode *Use : Ptr->uses()) {
16976     if (Use == N)
16977       continue;
16978     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
16979       return false;
16980 
16981     // If Ptr may be folded into the addressing mode of another use, then it's
16982     // not profitable to do this transformation.
16983     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
16984       RealUse = true;
16985   }
16986 
16987   if (!RealUse)
16988     return false;
16989 
16990   SDValue Result;
16991   if (!IsMasked) {
16992     if (IsLoad)
16993       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16994     else
16995       Result =
16996           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16997   } else {
16998     if (IsLoad)
16999       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
17000                                         Offset, AM);
17001     else
17002       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
17003                                          Offset, AM);
17004   }
17005   ++PreIndexedNodes;
17006   ++NodesCombined;
17007   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
17008              Result.dump(&DAG); dbgs() << '\n');
17009   WorklistRemover DeadNodes(*this);
17010   if (IsLoad) {
17011     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
17012     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
17013   } else {
17014     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
17015   }
17016 
17017   // Finally, since the node is now dead, remove it from the graph.
17018   deleteAndRecombine(N);
17019 
17020   if (Swapped)
17021     std::swap(BasePtr, Offset);
17022 
17023   // Replace other uses of BasePtr that can be updated to use Ptr
17024   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
17025     unsigned OffsetIdx = 1;
17026     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
17027       OffsetIdx = 0;
17028     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
17029            BasePtr.getNode() && "Expected BasePtr operand");
17030 
17031     // We need to replace ptr0 in the following expression:
17032     //   x0 * offset0 + y0 * ptr0 = t0
17033     // knowing that
17034     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
17035     //
17036     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
17037     // indexed load/store and the expression that needs to be re-written.
17038     //
17039     // Therefore, we have:
17040     //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
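    //
    // Concrete instance: with a pre-increment (x1 = y1 = 1) giving
    //   t1 = ptr0 + 8
    // and another use t0 = ptr0 + 12 (x0 = y0 = 1), the formula yields
    //   t0 = (12 - 8) + t1 = t1 + 4.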
17041 
17042     auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
17043     const APInt &Offset0 = CN->getAPIntValue();
17044     const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
17045     int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
17046     int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
17047     int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
17048     int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
17049 
17050     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
17051 
17052     APInt CNV = Offset0;
17053     if (X0 < 0) CNV = -CNV;
17054     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
17055     else CNV = CNV - Offset1;
17056 
17057     SDLoc DL(OtherUses[i]);
17058 
17059     // We can now generate the new expression.
17060     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
17061     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
17062 
17063     SDValue NewUse = DAG.getNode(Opcode,
17064                                  DL,
17065                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
17066     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
17067     deleteAndRecombine(OtherUses[i]);
17068   }
17069 
17070   // Replace the uses of Ptr with uses of the updated base value.
17071   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
17072   deleteAndRecombine(Ptr.getNode());
17073   AddToWorklist(Result.getNode());
17074 
17075   return true;
17076 }
17077 
17078 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
17079                                    SDValue &BasePtr, SDValue &Offset,
17080                                    ISD::MemIndexedMode &AM,
17081                                    SelectionDAG &DAG,
17082                                    const TargetLowering &TLI) {
17083   if (PtrUse == N ||
17084       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
17085     return false;
17086 
17087   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
17088     return false;
17089 
17090   // Don't create an indexed load / store with zero offset.
17091   if (isNullConstant(Offset))
17092     return false;
17093 
17094   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
17095     return false;
17096 
17097   SmallPtrSet<const SDNode *, 32> Visited;
17098   for (SDNode *Use : BasePtr->uses()) {
17099     if (Use == Ptr.getNode())
17100       continue;
17101 
17102     // Don't combine if there's a later user which could perform the index instead.
17103     if (isa<MemSDNode>(Use)) {
17104       bool IsLoad = true;
17105       bool IsMasked = false;
17106       SDValue OtherPtr;
17107       if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
17108                                    IsMasked, OtherPtr, TLI)) {
17109         SmallVector<const SDNode *, 2> Worklist;
17110         Worklist.push_back(Use);
17111         if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
17112           return false;
17113       }
17114     }
17115 
17116     // If all the uses are load / store addresses, then don't do the
17117     // transformation.
17118     if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
17119       for (SDNode *UseUse : Use->uses())
17120         if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
17121           return false;
17122     }
17123   }
17124   return true;
17125 }
17126 
17127 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
17128                                          bool &IsMasked, SDValue &Ptr,
17129                                          SDValue &BasePtr, SDValue &Offset,
17130                                          ISD::MemIndexedMode &AM,
17131                                          SelectionDAG &DAG,
17132                                          const TargetLowering &TLI) {
17133   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
17134                                 IsMasked, Ptr, TLI) ||
17135       Ptr->hasOneUse())
17136     return nullptr;
17137 
17138   // Try turning it into a post-indexed load / store except when
17139   // 1) All uses are load / store ops that use it as base ptr (and
17140   //    it may be folded as an addressing mode).
17141   // 2) Op must be independent of N, i.e. Op is neither a predecessor
17142   //    nor a successor of N. Otherwise, if Op is folded that would
17143   //    create a cycle.
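  //
  // Illustrative shape of the rewrite (exact forms are target-specific):
  //   val  = load ptr
  //   ptr2 = add ptr, 4
  // becomes a post-indexed load that loads from ptr and also produces the
  // incremented pointer (e.g. an AArch64 post-increment form like
  // "ldr x0, [x1], #4"), replacing ptr2.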
17144   for (SDNode *Op : Ptr->uses()) {
17145     // Check for #1.
17146     if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
17147       continue;
17148 
17149     // Check for #2.
17150     SmallPtrSet<const SDNode *, 32> Visited;
17151     SmallVector<const SDNode *, 8> Worklist;
17152     // Ptr is predecessor to both N and Op.
17153     Visited.insert(Ptr.getNode());
17154     Worklist.push_back(N);
17155     Worklist.push_back(Op);
17156     if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
17157         !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
17158       return Op;
17159   }
17160   return nullptr;
17161 }
17162 
17163 /// Try to combine a load/store with an add/sub of the base pointer node into a
17164 /// post-indexed load/store. The transformation effectively folds the
17165 /// add/subtract into the new indexed load/store, and all of its other uses are
17166 /// redirected to the new load/store.
17167 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
17168   if (Level < AfterLegalizeDAG)
17169     return false;
17170 
17171   bool IsLoad = true;
17172   bool IsMasked = false;
17173   SDValue Ptr;
17174   SDValue BasePtr;
17175   SDValue Offset;
17176   ISD::MemIndexedMode AM = ISD::UNINDEXED;
17177   SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
17178                                          Offset, AM, DAG, TLI);
17179   if (!Op)
17180     return false;
17181 
17182   SDValue Result;
17183   if (!IsMasked)
17184     Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
17185                                          Offset, AM)
17186                     : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
17187                                           BasePtr, Offset, AM);
17188   else
17189     Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
17190                                                BasePtr, Offset, AM)
17191                     : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
17192                                                 BasePtr, Offset, AM);
17193   ++PostIndexedNodes;
17194   ++NodesCombined;
17195   LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
17196              Result.dump(&DAG); dbgs() << '\n');
17197   WorklistRemover DeadNodes(*this);
17198   if (IsLoad) {
17199     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
17200     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
17201   } else {
17202     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
17203   }
17204 
17205   // Finally, since the node is now dead, remove it from the graph.
17206   deleteAndRecombine(N);
17207 
17208   // Replace the uses of Op with uses of the updated base value.
17209   DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
17210                                 Result.getValue(IsLoad ? 1 : 0));
17211   deleteAndRecombine(Op);
17212   return true;
17213 }
17214 
17215 /// Return the base-pointer arithmetic from an indexed \p LD.
17216 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
17217   ISD::MemIndexedMode AM = LD->getAddressingMode();
17218   assert(AM != ISD::UNINDEXED);
17219   SDValue BP = LD->getOperand(1);
17220   SDValue Inc = LD->getOperand(2);
17221 
17222   // Some backends use TargetConstants for load offsets, but don't expect
17223   // TargetConstants in general ADD nodes. We can convert these constants into
17224   // regular Constants (if the constant is not opaque).
17225   assert((Inc.getOpcode() != ISD::TargetConstant ||
17226           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
17227          "Cannot split out indexing using opaque target constants");
17228   if (Inc.getOpcode() == ISD::TargetConstant) {
17229     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
17230     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
17231                           ConstInc->getValueType(0));
17232   }
17233 
17234   unsigned Opc =
17235       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
17236   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
17237 }
17238 
17239 static inline ElementCount numVectorEltsOrZero(EVT T) {
17240   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
17241 }
17242 
17243 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
17244   EVT STType = Val.getValueType();
17245   EVT STMemType = ST->getMemoryVT();
17246   if (STType == STMemType)
17247     return true;
17248   if (isTypeLegal(STMemType))
17249     return false; // fail.
17250   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
17251       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
17252     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
17253     return true;
17254   }
17255   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
17256       STType.isInteger() && STMemType.isInteger()) {
17257     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
17258     return true;
17259   }
17260   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
17261     Val = DAG.getBitcast(STMemType, Val);
17262     return true;
17263   }
17264   return false; // fail.
17265 }
17266 
17267 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
17268   EVT LDMemType = LD->getMemoryVT();
17269   EVT LDType = LD->getValueType(0);
17270   assert(Val.getValueType() == LDMemType &&
17271          "Attempting to extend value of non-matching type");
17272   if (LDType == LDMemType)
17273     return true;
17274   if (LDMemType.isInteger() && LDType.isInteger()) {
17275     switch (LD->getExtensionType()) {
17276     case ISD::NON_EXTLOAD:
17277       Val = DAG.getBitcast(LDType, Val);
17278       return true;
17279     case ISD::EXTLOAD:
17280       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
17281       return true;
17282     case ISD::SEXTLOAD:
17283       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
17284       return true;
17285     case ISD::ZEXTLOAD:
17286       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
17287       return true;
17288     }
17289   }
17290   return false;
17291 }
17292 
17293 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
17294   if (OptLevel == CodeGenOpt::None || !LD->isSimple())
17295     return SDValue();
17296   SDValue Chain = LD->getOperand(0);
17297   StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
17298   // TODO: Relax this restriction for unordered atomics (see D66309)
17299   if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
17300     return SDValue();
17301 
17302   EVT LDType = LD->getValueType(0);
17303   EVT LDMemType = LD->getMemoryVT();
17304   EVT STMemType = ST->getMemoryVT();
17305   EVT STType = ST->getValue().getValueType();
17306 
17307   // There are two cases to consider here:
17308   //  1. The store is fixed width and the load is scalable. In this case we
17309   //     don't know at compile time if the store completely envelops the load
17310   //     so we abandon the optimisation.
17311   //  2. The store is scalable and the load is fixed width. We could
17312   //     potentially support a limited number of cases here, but there has been
17313   //     no cost-benefit analysis to prove it's worth it.
17314   bool LdStScalable = LDMemType.isScalableVector();
17315   if (LdStScalable != STMemType.isScalableVector())
17316     return SDValue();
17317 
17318   // If we are dealing with scalable vectors on a big endian platform the
17319   // calculation of offsets below becomes trickier, since we do not know at
17320   // compile time the absolute size of the vector. Until we've done more
17321   // analysis on big-endian platforms it seems better to bail out for now.
17322   if (LdStScalable && DAG.getDataLayout().isBigEndian())
17323     return SDValue();
17324 
17325   BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
17326   BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
17327   int64_t Offset;
17328   if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
17329     return SDValue();
17330 
17331   // Normalize for endianness. After this, Offset=0 denotes that the least
17332   // significant bit in the loaded value maps to the least significant bit in
17333   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
17334   // n:th least significant byte of the stored value.
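  //
  // Sketch: storing i32 %v at %p and loading i8 at %p+1 gives a raw Offset of
  // 1; on little-endian that already means byte 1 (bits 8-15) of %v, while on
  // big-endian it normalizes to (32 - 8)/8 - 1 = 2, i.e. bits 16-23 of %v.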
17335   int64_t OrigOffset = Offset;
17336   if (DAG.getDataLayout().isBigEndian())
17337     Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
17338               (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
17339                  8 -
17340              Offset;
17341 
17342   // Check that the stored value covers all bits that are loaded.
17343   bool STCoversLD;
17344 
17345   TypeSize LdMemSize = LDMemType.getSizeInBits();
17346   TypeSize StMemSize = STMemType.getSizeInBits();
17347   if (LdStScalable)
17348     STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
17349   else
17350     STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
17351                                    StMemSize.getFixedValue());
17352 
17353   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
17354     if (LD->isIndexed()) {
17355       // Cannot handle opaque target constants and we must respect the user's
17356       // request not to split indexes from loads.
17357       if (!canSplitIdx(LD))
17358         return SDValue();
17359       SDValue Idx = SplitIndexingFromLoad(LD);
17360       SDValue Ops[] = {Val, Idx, Chain};
17361       return CombineTo(LD, Ops, 3);
17362     }
17363     return CombineTo(LD, Val, Chain);
17364   };
17365 
17366   if (!STCoversLD)
17367     return SDValue();
17368 
17369   // Memory as copy space (potentially masked).
17370   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
17371     // Simple case: Direct non-truncating forwarding
17372     if (LDType.getSizeInBits() == LdMemSize)
17373       return ReplaceLd(LD, ST->getValue(), Chain);
17374     // Can we model the truncate and extension with an and mask?
17375     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
17376         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
17377       // Mask to size of LDMemType
17378       auto Mask =
17379           DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
17380                                                StMemSize.getFixedValue()),
17381                           SDLoc(ST), STType);
17382       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
17383       return ReplaceLd(LD, Val, Chain);
17384     }
17385   }
17386 
17387   // Handle some big-endian cases that would have Offset 0 (and thus already
17388   // be handled above) on little-endian targets.
17389   SDValue Val = ST->getValue();
17390   if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
17391     if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
17392         !LDType.isVector() && isTypeLegal(STType) &&
17393         TLI.isOperationLegal(ISD::SRL, STType)) {
17394       Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
17395                         DAG.getConstant(Offset * 8, SDLoc(LD), STType));
17396       Offset = 0;
17397     }
17398   }
17399 
17400   // TODO: Deal with nonzero offset.
17401   if (LD->getBasePtr().isUndef() || Offset != 0)
17402     return SDValue();
17403   // Model necessary truncations / extensions.
17404   // Truncate the value to the stored memory size.
17405   do {
17406     if (!getTruncatedStoreValue(ST, Val))
17407       continue;
17408     if (!isTypeLegal(LDMemType))
17409       continue;
17410     if (STMemType != LDMemType) {
17411       // TODO: Support vectors? This requires extract_subvector/bitcast.
17412       if (!STMemType.isVector() && !LDMemType.isVector() &&
17413           STMemType.isInteger() && LDMemType.isInteger())
17414         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
17415       else
17416         continue;
17417     }
17418     if (!extendLoadedValueToExtension(LD, Val))
17419       continue;
17420     return ReplaceLd(LD, Val, Chain);
17421   } while (false);
17422 
17423   // On failure, clean up dead nodes we may have created.
17424   if (Val->use_empty())
17425     deleteAndRecombine(Val.getNode());
17426   return SDValue();
17427 }
17428 
17429 SDValue DAGCombiner::visitLOAD(SDNode *N) {
17430   LoadSDNode *LD  = cast<LoadSDNode>(N);
17431   SDValue Chain = LD->getChain();
17432   SDValue Ptr   = LD->getBasePtr();
17433 
17434   // If the load is not volatile and there are no uses of the loaded value (and
17435   // the updated indexed value in case of indexed loads), change uses of the
17436   // chain value into uses of the chain input (i.e. delete the dead load).
17437   // TODO: Allow this for unordered atomics (see D66309)
17438   if (LD->isSimple()) {
17439     if (N->getValueType(1) == MVT::Other) {
17440       // Unindexed loads.
17441       if (!N->hasAnyUseOfValue(0)) {
17442         // It's not safe to use the two value CombineTo variant here. e.g.
17443         // v1, chain2 = load chain1, loc
17444         // v2, chain3 = load chain2, loc
17445         // v3         = add v2, c
17446         // Now we replace use of chain2 with chain1.  This makes the second load
17447         // isomorphic to the one we are deleting, and thus makes this load live.
17448         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
17449                    dbgs() << "\nWith chain: "; Chain.dump(&DAG);
17450                    dbgs() << "\n");
17451         WorklistRemover DeadNodes(*this);
17452         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
17453         AddUsersToWorklist(Chain.getNode());
17454         if (N->use_empty())
17455           deleteAndRecombine(N);
17456 
17457         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
17458       }
17459     } else {
17460       // Indexed loads.
17461       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
17462 
17463       // If this load has an opaque TargetConstant offset, then we cannot split
17464       // the indexing into an add/sub directly (that TargetConstant may not be
17465       // valid for a different type of node, and we cannot convert an opaque
17466       // target constant into a regular constant).
17467       bool CanSplitIdx = canSplitIdx(LD);
17468 
17469       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
17470         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
17471         SDValue Index;
17472         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
17473           Index = SplitIndexingFromLoad(LD);
17474           // Try to fold the base pointer arithmetic into subsequent loads and
17475           // stores.
17476           AddUsersToWorklist(N);
17477         } else
17478           Index = DAG.getUNDEF(N->getValueType(1));
17479         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
17480                    dbgs() << "\nWith: "; Undef.dump(&DAG);
17481                    dbgs() << " and 2 other values\n");
17482         WorklistRemover DeadNodes(*this);
17483         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
17484         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
17485         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
17486         deleteAndRecombine(N);
17487         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
17488       }
17489     }
17490   }
17491 
17492   // If this load is directly stored, replace the load value with the stored
17493   // value.
17494   if (auto V = ForwardStoreValueToDirectLoad(LD))
17495     return V;
17496 
17497   // Try to infer better alignment information than the load already has.
17498   if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
17499     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
17500       if (*Alignment > LD->getAlign() &&
17501           isAligned(*Alignment, LD->getSrcValueOffset())) {
17502         SDValue NewLoad = DAG.getExtLoad(
17503             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
17504             LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
17505             LD->getMemOperand()->getFlags(), LD->getAAInfo());
17506         // NewLoad will always be N as we are only refining the alignment
17507         assert(NewLoad.getNode() == N);
17508         (void)NewLoad;
17509       }
17510     }
17511   }
17512 
17513   if (LD->isUnindexed()) {
17514     // Walk up chain skipping non-aliasing memory nodes.
17515     SDValue BetterChain = FindBetterChain(LD, Chain);
17516 
17517     // If there is a better chain.
17518     if (Chain != BetterChain) {
17519       SDValue ReplLoad;
17520 
17521       // Replace the chain to avoid the dependency.
17522       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
17523         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
17524                                BetterChain, Ptr, LD->getMemOperand());
17525       } else {
17526         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
17527                                   LD->getValueType(0),
17528                                   BetterChain, Ptr, LD->getMemoryVT(),
17529                                   LD->getMemOperand());
17530       }
17531 
17532       // Create token factor to keep old chain connected.
17533       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
17534                                   MVT::Other, Chain, ReplLoad.getValue(1));
17535 
17536       // Replace uses with load result and token factor
17537       return CombineTo(N, ReplLoad.getValue(0), Token);
17538     }
17539   }
17540 
17541   // Try transforming N to an indexed load.
17542   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
17543     return SDValue(N, 0);
17544 
17545   // Try to slice up N to more direct loads if the slices are mapped to
17546   // different register banks or pairing can take place.
17547   if (SliceUpLoad(N))
17548     return SDValue(N, 0);
17549 
17550   return SDValue();
17551 }
17552 
17553 namespace {
17554 
17555 /// Helper structure used to slice a load in smaller loads.
17556 /// Basically a slice is obtained from the following sequence:
17557 /// Origin = load Ty1, Base
17558 /// Shift = srl Ty1 Origin, CstTy Amount
17559 /// Inst = trunc Shift to Ty2
17560 ///
17561 /// Then, it will be rewritten into:
17562 /// Slice = load SliceTy, Base + SliceOffset
17563 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
17564 ///
17565 /// SliceTy is deduced from the number of bits that are actually used to
17566 /// build Inst.
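///
/// Illustrative slice on a little-endian target:
///   Origin = load i32, Base
///   Shift  = srl i32 Origin, 16
///   Inst   = trunc to i16
/// can be rewritten as a single i16 load from Base + 2.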
17567 struct LoadedSlice {
17568   /// Helper structure used to compute the cost of a slice.
17569   struct Cost {
17570     /// Are we optimizing for code size.
17571     bool ForCodeSize = false;
17572 
17573     /// Various costs.
17574     unsigned Loads = 0;
17575     unsigned Truncates = 0;
17576     unsigned CrossRegisterBanksCopies = 0;
17577     unsigned ZExts = 0;
17578     unsigned Shift = 0;
17579 
17580     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
17581 
17582     /// Get the cost of one isolated slice.
17583     Cost(const LoadedSlice &LS, bool ForCodeSize)
17584         : ForCodeSize(ForCodeSize), Loads(1) {
17585       EVT TruncType = LS.Inst->getValueType(0);
17586       EVT LoadedType = LS.getLoadedType();
17587       if (TruncType != LoadedType &&
17588           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
17589         ZExts = 1;
17590     }
17591 
17592     /// Account for slicing gain in the current cost.
17593     /// Slicing provides a few gains, like removing a shift or a
17594     /// truncate. This method allows the cost of the original
17595     /// load to be grown by the gain from this slice.
17596     void addSliceGain(const LoadedSlice &LS) {
17597       // Each slice saves a truncate.
17598       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
17599       if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
17600                               LS.Inst->getValueType(0)))
17601         ++Truncates;
17602       // If there is a shift amount, this slice gets rid of it.
17603       if (LS.Shift)
17604         ++Shift;
17605       // If this slice can merge a cross register bank copy, account for it.
17606       if (LS.canMergeExpensiveCrossRegisterBankCopy())
17607         ++CrossRegisterBanksCopies;
17608     }
17609 
17610     Cost &operator+=(const Cost &RHS) {
17611       Loads += RHS.Loads;
17612       Truncates += RHS.Truncates;
17613       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
17614       ZExts += RHS.ZExts;
17615       Shift += RHS.Shift;
17616       return *this;
17617     }
17618 
17619     bool operator==(const Cost &RHS) const {
17620       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
17621              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
17622              ZExts == RHS.ZExts && Shift == RHS.Shift;
17623     }
17624 
17625     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
17626 
17627     bool operator<(const Cost &RHS) const {
17628       // Assume cross register banks copies are as expensive as loads.
17629       // FIXME: Do we want some more target hooks?
17630       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
17631       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
17632       // Unless we are optimizing for code size, consider the
17633       // expensive operation first.
17634       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
17635         return ExpensiveOpsLHS < ExpensiveOpsRHS;
17636       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
17637              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
17638     }
17639 
17640     bool operator>(const Cost &RHS) const { return RHS < *this; }
17641 
17642     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
17643 
17644     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
17645   };
17646 
17647   // The last instruction that represents the slice. This should be a
17648   // truncate instruction.
17649   SDNode *Inst;
17650 
17651   // The original load instruction.
17652   LoadSDNode *Origin;
17653 
17654   // The right shift amount in bits from the original load.
17655   unsigned Shift;
17656 
17657   // The DAG from which Origin came.
17658   // This is used to get some contextual information about legal types, etc.
17659   SelectionDAG *DAG;
17660 
17661   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
17662               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
17663       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
17664 
17665   /// Get the bits used in a chunk of bits \p BitWidth large.
17666   /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
17667   ///         unused bits set to 0.
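  /// For example, with BitWidth = 32, an i8 truncation and Shift = 16, the
  /// result is 0x00FF0000.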
17668   APInt getUsedBits() const {
17669     // Reproduce the trunc(lshr) sequence:
17670     // - Start from the truncated value.
17671     // - Zero extend to the desired bit width.
17672     // - Shift left.
17673     assert(Origin && "No original load to compare against.");
17674     unsigned BitWidth = Origin->getValueSizeInBits(0);
17675     assert(Inst && "This slice is not bound to an instruction");
17676     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
17677            "Extracted slice is bigger than the whole type!");
17678     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
17679     UsedBits.setAllBits();
17680     UsedBits = UsedBits.zext(BitWidth);
17681     UsedBits <<= Shift;
17682     return UsedBits;
17683   }
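  // For example (illustrative values): with an i32 Origin, an i8 Inst and
  // Shift == 16, the mask starts as i8 0xFF, is zero-extended to i32
  // (0x000000FF), then shifted left by 16, giving UsedBits == 0x00FF0000.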
17684 
17685   /// Get the size of the slice to be loaded in bytes.
17686   unsigned getLoadedSize() const {
17687     unsigned SliceSize = getUsedBits().countPopulation();
17688     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
17689     return SliceSize / 8;
17690   }
17691 
17692   /// Get the type that will be loaded for this slice.
17693   /// Note: This may not be the final type for the slice.
17694   EVT getLoadedType() const {
17695     assert(DAG && "Missing context");
17696     LLVMContext &Ctxt = *DAG->getContext();
17697     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
17698   }
17699 
17700   /// Get the alignment of the load used for this slice.
17701   Align getAlign() const {
17702     Align Alignment = Origin->getAlign();
17703     uint64_t Offset = getOffsetFromBase();
17704     if (Offset != 0)
17705       Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
17706     return Alignment;
17707   }
17708 
17709   /// Check if this slice can be rewritten with legal operations.
17710   bool isLegal() const {
17711     // An invalid slice is not legal.
17712     if (!Origin || !Inst || !DAG)
17713       return false;
17714 
17715     // Offsets are for indexed loads only; we do not handle that.
17716     if (!Origin->getOffset().isUndef())
17717       return false;
17718 
17719     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17720 
17721     // Check that the type is legal.
17722     EVT SliceType = getLoadedType();
17723     if (!TLI.isTypeLegal(SliceType))
17724       return false;
17725 
17726     // Check that the load is legal for this type.
17727     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
17728       return false;
17729 
17730     // Check that the offset can be computed.
17731     // 1. Check its type.
17732     EVT PtrType = Origin->getBasePtr().getValueType();
17733     if (PtrType == MVT::Untyped || PtrType.isExtended())
17734       return false;
17735 
17736     // 2. Check that it fits in the immediate.
17737     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
17738       return false;
17739 
17740     // 3. Check that the computation is legal.
17741     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
17742       return false;
17743 
17744     // Check that the zext is legal if it needs one.
17745     EVT TruncateType = Inst->getValueType(0);
17746     if (TruncateType != SliceType &&
17747         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
17748       return false;
17749 
17750     return true;
17751   }
17752 
17753   /// Get the offset in bytes of this slice in the original chunk of
17754   /// bits.
17755   /// \pre DAG != nullptr.
17756   uint64_t getOffsetFromBase() const {
17757     assert(DAG && "Missing context.");
17758     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
17759     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
17760     uint64_t Offset = Shift / 8;
17761     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
17762     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
17763            "The size of the original loaded type is not a multiple of a"
17764            " byte.");
17765     // If Offset is bigger than TySizeInBytes, it means we are loading all
17766     // zeros. This should have been optimized away earlier in the process.
17767     assert(TySizeInBytes > Offset &&
17768            "Invalid shift amount for given loaded size");
17769     if (IsBigEndian)
17770       Offset = TySizeInBytes - Offset - getLoadedSize();
17771     return Offset;
17772   }
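  // For example (illustrative values): with an i32 Origin, Shift == 16 and a
  // one-byte slice, Offset == 2 on little-endian targets; on big-endian
  // targets the same bits live at Offset == 4 - 2 - 1 == 1.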
17773 
17774   /// Generate the sequence of instructions to load the slice
17775   /// represented by this object and redirect the uses of this slice to
17776   /// this new sequence of instructions.
17777   /// \pre this->Inst && this->Origin are valid Instructions and this
17778   /// object passed the legal check: LoadedSlice::isLegal returned true.
17779   /// \return The last instruction of the sequence used to load the slice.
17780   SDValue loadSlice() const {
17781     assert(Inst && Origin && "Unable to replace a non-existing slice.");
17782     const SDValue &OldBaseAddr = Origin->getBasePtr();
17783     SDValue BaseAddr = OldBaseAddr;
17784     // Get the offset in that chunk of bytes w.r.t. the endianness.
17785     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
17786     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
17787     if (Offset) {
17788       // BaseAddr = BaseAddr + Offset.
17789       EVT ArithType = BaseAddr.getValueType();
17790       SDLoc DL(Origin);
17791       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
17792                               DAG->getConstant(Offset, DL, ArithType));
17793     }
17794 
17795     // Create the type of the loaded slice according to its size.
17796     EVT SliceType = getLoadedType();
17797 
17798     // Create the load for the slice.
17799     SDValue LastInst =
17800         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
17801                      Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
17802                      Origin->getMemOperand()->getFlags());
17803     // If the final type is not the same as the loaded type, this means that
17804     // we have to pad with zero. Create a zero extend for that.
17805     EVT FinalType = Inst->getValueType(0);
17806     if (SliceType != FinalType)
17807       LastInst =
17808           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
17809     return LastInst;
17810   }
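  // For example (illustrative, little-endian): the slice
  //   (i8 (trunc (srl (load i32 p), 16)))
  // is rewritten as
  //   (i8 (load (add p, 2)))
  // and, since SliceType == FinalType here, no zero extend is emitted.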
17811 
17812   /// Check if this slice can be merged with an expensive cross register
17813   /// bank copy. E.g.,
17814   /// i = load i32
17815   /// f = bitcast i32 i to float
17816   bool canMergeExpensiveCrossRegisterBankCopy() const {
17817     if (!Inst || !Inst->hasOneUse())
17818       return false;
17819     SDNode *Use = *Inst->use_begin();
17820     if (Use->getOpcode() != ISD::BITCAST)
17821       return false;
17822     assert(DAG && "Missing context");
17823     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17824     EVT ResVT = Use->getValueType(0);
17825     const TargetRegisterClass *ResRC =
17826         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
17827     const TargetRegisterClass *ArgRC =
17828         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
17829                            Use->getOperand(0)->isDivergent());
17830     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
17831       return false;
17832 
17833     // At this point, we know that we perform a cross-register-bank copy.
17834     // Check if it is expensive.
17835     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
17836     // Assume bitcasts are cheap, unless both register classes do not
17837     // explicitly share a common sub class.
17838     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
17839       return false;
17840 
17841     // Check if it will be merged with the load.
17842     // 1. Check the alignment / fast memory access constraint.
17843     unsigned IsFast = 0;
17844     if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
17845                                 Origin->getAddressSpace(), getAlign(),
17846                                 Origin->getMemOperand()->getFlags(), &IsFast) ||
17847         !IsFast)
17848       return false;
17849 
17850     // 2. Check that the load is a legal operation for that type.
17851     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
17852       return false;
17853 
17854     // 3. Check that we do not have a zext in the way.
17855     if (Inst->getValueType(0) != getLoadedType())
17856       return false;
17857 
17858     return true;
17859   }
17860 };
17861 
17862 } // end anonymous namespace
17863 
17864 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
17865 /// \p UsedBits looks like 0..0 1..1 0..0.
17866 static bool areUsedBitsDense(const APInt &UsedBits) {
17867   // If all the bits are one, this is dense!
17868   if (UsedBits.isAllOnes())
17869     return true;
17870 
17871   // Get rid of the unused bits on the right.
17872   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
17873   // Get rid of the unused bits on the left.
17874   if (NarrowedUsedBits.countLeadingZeros())
17875     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
17876   // Check that the chunk of bits is completely used.
17877   return NarrowedUsedBits.isAllOnes();
17878 }
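// For example (illustrative values): 0x00FF0000 is dense (a single
// contiguous run of ones), whereas 0x00F000F0 is not, because stripping the
// trailing and leading zeros still leaves a hole (0xF000F).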
17879 
17880 /// Check whether or not \p First and \p Second are next to each other
17881 /// in memory. This means that there is no hole between the bits loaded
17882 /// by \p First and the bits loaded by \p Second.
17883 static bool areSlicesNextToEachOther(const LoadedSlice &First,
17884                                      const LoadedSlice &Second) {
17885   assert(First.Origin == Second.Origin && First.Origin &&
17886          "Unable to match different memory origins.");
17887   APInt UsedBits = First.getUsedBits();
17888   assert((UsedBits & Second.getUsedBits()) == 0 &&
17889          "Slices are not supposed to overlap.");
17890   UsedBits |= Second.getUsedBits();
17891   return areUsedBitsDense(UsedBits);
17892 }
17893 
17894 /// Adjust the \p GlobalLSCost according to the target
17895 /// pairing capabilities and the layout of the slices.
17896 /// \pre \p GlobalLSCost should account for at least as many loads as
17897 /// there are in the slices in \p LoadedSlices.
17898 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17899                                  LoadedSlice::Cost &GlobalLSCost) {
17900   unsigned NumberOfSlices = LoadedSlices.size();
17901   // If there are fewer than 2 elements, no pairing is possible.
17902   if (NumberOfSlices < 2)
17903     return;
17904 
17905   // Sort the slices so that elements that are likely to be next to each
17906   // other in memory are next to each other in the list.
17907   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
17908     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
17909     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
17910   });
17911   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
17912   // First (resp. Second) is the first (resp. second) potential candidate
17913   // to be placed in a paired load.
17914   const LoadedSlice *First = nullptr;
17915   const LoadedSlice *Second = nullptr;
17916   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
17917                 // Set the beginning of the pair.
17918                                                            First = Second) {
17919     Second = &LoadedSlices[CurrSlice];
17920 
17921     // If First is NULL, it means we start a new pair.
17922     // Get to the next slice.
17923     if (!First)
17924       continue;
17925 
17926     EVT LoadedType = First->getLoadedType();
17927 
17928     // If the types of the slices are different, we cannot pair them.
17929     if (LoadedType != Second->getLoadedType())
17930       continue;
17931 
17932     // Check if the target supplies paired loads for this type.
17933     Align RequiredAlignment;
17934     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
17935       // Move to the next pair; this type is hopeless.
17936       Second = nullptr;
17937       continue;
17938     }
17939     // Check if we meet the alignment requirement.
17940     if (First->getAlign() < RequiredAlignment)
17941       continue;
17942 
17943     // Check that both loads are next to each other in memory.
17944     if (!areSlicesNextToEachOther(*First, *Second))
17945       continue;
17946 
17947     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
17948     --GlobalLSCost.Loads;
17949     // Move to the next pair.
17950     Second = nullptr;
17951   }
17952 }
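// For example (illustrative): if two i16 slices sit at offsets 0 and 2 and
// the target reports hasPairedLoad(i16, ...) with a satisfiable alignment,
// the pair counts as one load, so GlobalLSCost.Loads drops by 1.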
17953 
17954 /// Check the profitability of all involved LoadedSlice.
17955 /// Currently, it is considered profitable if there are exactly two
17956 /// involved slices (1) which are (2) next to each other in memory, and
17957 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
17958 ///
17959 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
17960 /// the elements themselves.
17961 ///
17962 /// FIXME: When the cost model will be mature enough, we can relax
17963 /// constraints (1) and (2).
17964 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17965                                 const APInt &UsedBits, bool ForCodeSize) {
17966   unsigned NumberOfSlices = LoadedSlices.size();
17967   if (StressLoadSlicing)
17968     return NumberOfSlices > 1;
17969 
17970   // Check (1).
17971   if (NumberOfSlices != 2)
17972     return false;
17973 
17974   // Check (2).
17975   if (!areUsedBitsDense(UsedBits))
17976     return false;
17977 
17978   // Check (3).
17979   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
17980   // The original code has one big load.
17981   OrigCost.Loads = 1;
17982   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
17983     const LoadedSlice &LS = LoadedSlices[CurrSlice];
17984     // Accumulate the cost of all the slices.
17985     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
17986     GlobalSlicingCost += SliceCost;
17987 
17988     // Account as cost in the original configuration the gain obtained
17989     // with the current slices.
17990     OrigCost.addSliceGain(LS);
17991   }
17992 
17993   // If the target supports paired load, adjust the cost accordingly.
17994   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
17995   return OrigCost > GlobalSlicingCost;
17996 }
17997 
17998 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
17999 /// operations, split it in the various pieces being extracted.
18000 ///
18001 /// This sort of thing is introduced by SROA.
18002 /// This slicing takes care not to insert overlapping loads.
18003 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
18004 bool DAGCombiner::SliceUpLoad(SDNode *N) {
18005   if (Level < AfterLegalizeDAG)
18006     return false;
18007 
18008   LoadSDNode *LD = cast<LoadSDNode>(N);
18009   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
18010       !LD->getValueType(0).isInteger())
18011     return false;
18012 
18013   // The algorithm to split up a load of a scalable vector into individual
18014   // elements currently requires knowing the length of the loaded type,
18015   // so will need adjusting to work on scalable vectors.
18016   if (LD->getValueType(0).isScalableVector())
18017     return false;
18018 
18019   // Keep track of already used bits to detect overlapping values.
18020   // In that case, we will just abort the transformation.
18021   APInt UsedBits(LD->getValueSizeInBits(0), 0);
18022 
18023   SmallVector<LoadedSlice, 4> LoadedSlices;
18024 
18025   // Check if this load is used as several smaller chunks of bits.
18026   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
18027   // of computation for each trunc.
18028   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
18029        UI != UIEnd; ++UI) {
18030     // Skip the uses of the chain.
18031     if (UI.getUse().getResNo() != 0)
18032       continue;
18033 
18034     SDNode *User = *UI;
18035     unsigned Shift = 0;
18036 
18037     // Check if this is a trunc(lshr).
18038     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
18039         isa<ConstantSDNode>(User->getOperand(1))) {
18040       Shift = User->getConstantOperandVal(1);
18041       User = *User->use_begin();
18042     }
18043 
18044     // At this point, User is a Truncate iff we encountered trunc or
18045     // trunc(lshr).
18046     if (User->getOpcode() != ISD::TRUNCATE)
18047       return false;
18048 
18049     // The width of the type must be a power of 2 and at least 8 bits.
18050     // Otherwise the load cannot be represented in LLVM IR.
18051     // Moreover, if we shifted by an amount that is not a multiple of 8
18052     // bits, the slice would straddle byte boundaries. We do not support that.
18053     unsigned Width = User->getValueSizeInBits(0);
18054     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
18055       return false;
18056 
18057     // Build the slice for this chain of computations.
18058     LoadedSlice LS(User, LD, Shift, &DAG);
18059     APInt CurrentUsedBits = LS.getUsedBits();
18060 
18061     // Check if this slice overlaps with another.
18062     if ((CurrentUsedBits & UsedBits) != 0)
18063       return false;
18064     // Update the bits used globally.
18065     UsedBits |= CurrentUsedBits;
18066 
18067     // Check if the new slice would be legal.
18068     if (!LS.isLegal())
18069       return false;
18070 
18071     // Record the slice.
18072     LoadedSlices.push_back(LS);
18073   }
18074 
18075   // Abort slicing if it does not seem to be profitable.
18076   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
18077     return false;
18078 
18079   ++SlicedLoads;
18080 
18081   // Rewrite each chain to use an independent load.
18082   // By construction, each chain can be represented by a unique load.
18083 
18084   // Prepare the argument for the new token factor for all the slices.
18085   SmallVector<SDValue, 8> ArgChains;
18086   for (const LoadedSlice &LS : LoadedSlices) {
18087     SDValue SliceInst = LS.loadSlice();
18088     CombineTo(LS.Inst, SliceInst, true);
18089     if (SliceInst.getOpcode() != ISD::LOAD)
18090       SliceInst = SliceInst.getOperand(0);
18091     assert(SliceInst->getOpcode() == ISD::LOAD &&
18092            "It takes more than a zext to get to the loaded slice!!");
18093     ArgChains.push_back(SliceInst.getValue(1));
18094   }
18095 
18096   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
18097                               ArgChains);
18098   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
18099   AddToWorklist(Chain.getNode());
18100   return true;
18101 }
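// For example (illustrative, little-endian), the SROA-style pattern
//   b  = load i32 p
//   t0 = trunc b to i16
//   t1 = trunc (srl b, 16) to i16
// may be sliced, subject to the legality and profitability checks above, into
//   t0 = load i16 p
//   t1 = load i16 (add p, 2)
// with a TokenFactor merging the two new chains.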
18102 
18103 /// Check to see if V is (and load (ptr), imm), where the load is having
18104 /// specific bytes cleared out.  If so, return the byte size being masked out
18105 /// and the shift amount.
18106 static std::pair<unsigned, unsigned>
18107 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
18108   std::pair<unsigned, unsigned> Result(0, 0);
18109 
18110   // Check for the structure we're looking for.
18111   if (V->getOpcode() != ISD::AND ||
18112       !isa<ConstantSDNode>(V->getOperand(1)) ||
18113       !ISD::isNormalLoad(V->getOperand(0).getNode()))
18114     return Result;
18115 
18116   // Check the chain and pointer.
18117   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
18118   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
18119 
18120   // This only handles simple types.
18121   if (V.getValueType() != MVT::i16 &&
18122       V.getValueType() != MVT::i32 &&
18123       V.getValueType() != MVT::i64)
18124     return Result;
18125 
18126   // Check the constant mask.  Invert it so that the bits being masked out are
18127   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
18128   // follow the sign bit for uniformity.
18129   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
18130   unsigned NotMaskLZ = countLeadingZeros(NotMask);
18131   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
18132   unsigned NotMaskTZ = countTrailingZeros(NotMask);
18133   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
18134   if (NotMaskLZ == 64) return Result;  // All zero mask.
18135 
18136   // See if we have a continuous run of bits.  If so, we have 0*1+0*
18137   if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
18138     return Result;
18139 
18140   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
18141   if (V.getValueType() != MVT::i64 && NotMaskLZ)
18142     NotMaskLZ -= 64-V.getValueSizeInBits();
18143 
18144   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
18145   switch (MaskedBytes) {
18146   case 1:
18147   case 2:
18148   case 4: break;
18149   default: return Result; // All one mask, or 5-byte mask.
18150   }
18151 
18152   // Verify that the masked region starts at a byte offset that is a multiple
18153   // of the access width, so the access is aligned the same as its width.
18154   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
18155 
18156   // For narrowing to be valid, it must be the case that the load is the
18157   // memory operation immediately preceding the store.
18158   if (LD == Chain.getNode())
18159     ; // ok.
18160   else if (Chain->getOpcode() == ISD::TokenFactor &&
18161            SDValue(LD, 1).hasOneUse()) {
18162     // LD has only 1 chain use, so there are no indirect dependencies.
18163     if (!LD->isOperandOf(Chain.getNode()))
18164       return Result;
18165   } else
18166     return Result; // Fail.
18167 
18168   Result.first = MaskedBytes;
18169   Result.second = NotMaskTZ/8;
18170   return Result;
18171 }
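// For example (illustrative values): for V = (and (load i32 p), 0xFFFF00FF),
// NotMask == 0xFF00, so NotMaskTZ == 8 and the masked region is one byte
// wide; the result is {MaskedBytes = 1, ByteShift = 1}.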
18172 
18173 /// Check to see if IVal is something that provides a value as specified by
18174 /// MaskInfo. If so, replace the specified store with a narrower store of
18175 /// truncated IVal.
18176 static SDValue
18177 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
18178                                 SDValue IVal, StoreSDNode *St,
18179                                 DAGCombiner *DC) {
18180   unsigned NumBytes = MaskInfo.first;
18181   unsigned ByteShift = MaskInfo.second;
18182   SelectionDAG &DAG = DC->getDAG();
18183 
18184   // Check to see if IVal is all zeros in the part being masked in by the 'or'
18185   // that uses this.  If not, this is not a replacement.
18186   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
18187                                   ByteShift*8, (ByteShift+NumBytes)*8);
18188   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
18189 
18190   // Check that it is legal on the target to do this.  It is legal if the new
18191   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
18192   // legalization. If the source type is legal, but the store type isn't, see
18193   // if we can use a truncating store.
18194   MVT VT = MVT::getIntegerVT(NumBytes * 8);
18195   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18196   bool UseTruncStore;
18197   if (DC->isTypeLegal(VT))
18198     UseTruncStore = false;
18199   else if (TLI.isTypeLegal(IVal.getValueType()) &&
18200            TLI.isTruncStoreLegal(IVal.getValueType(), VT))
18201     UseTruncStore = true;
18202   else
18203     return SDValue();
18204   // Check that the target doesn't think this is a bad idea.
18205   if (St->getMemOperand() &&
18206       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
18207                               *St->getMemOperand()))
18208     return SDValue();
18209 
18210   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
18211   // shifted by ByteShift and truncated down to NumBytes.
18212   if (ByteShift) {
18213     SDLoc DL(IVal);
18214     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
18215                        DAG.getConstant(ByteShift*8, DL,
18216                                     DC->getShiftAmountTy(IVal.getValueType())));
18217   }
18218 
18219   // Figure out the offset for the store and the alignment of the access.
18220   unsigned StOffset;
18221   if (DAG.getDataLayout().isLittleEndian())
18222     StOffset = ByteShift;
18223   else
18224     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
18225 
18226   SDValue Ptr = St->getBasePtr();
18227   if (StOffset) {
18228     SDLoc DL(IVal);
18229     Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
18230   }
18231 
18232   ++OpsNarrowed;
18233   if (UseTruncStore)
18234     return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
18235                              St->getPointerInfo().getWithOffset(StOffset),
18236                              VT, St->getOriginalAlign());
18237 
18238   // Truncate down to the new size.
18239   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
18240 
18241   return DAG
18242       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
18243                 St->getPointerInfo().getWithOffset(StOffset),
18244                 St->getOriginalAlign());
18245 }
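// For example (illustrative, little-endian): with MaskInfo == {1, 1} and
// IVal known to be zero outside bits 8..15 (e.g. (shl (zext i8 x), 8)),
// the original i32 store is replaced by an i8 store of
// (trunc (srl IVal, 8)) at address p + 1.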
18246 
18247 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
18248 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
18249 /// narrowing the load and store if it would end up being a win for performance
18250 /// or code size.
18251 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
18252   StoreSDNode *ST  = cast<StoreSDNode>(N);
18253   if (!ST->isSimple())
18254     return SDValue();
18255 
18256   SDValue Chain = ST->getChain();
18257   SDValue Value = ST->getValue();
18258   SDValue Ptr   = ST->getBasePtr();
18259   EVT VT = Value.getValueType();
18260 
18261   if (ST->isTruncatingStore() || VT.isVector())
18262     return SDValue();
18263 
18264   unsigned Opc = Value.getOpcode();
18265 
18266   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
18267       !Value.hasOneUse())
18268     return SDValue();
18269 
18270   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
18271   // is a byte mask indicating a consecutive number of bytes, check to see if
18272   // Y is known to provide just those bytes.  If so, we try to replace the
18273   // load + replace + store sequence with a single (narrower) store, which makes
18274   // the load dead.
18275   if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
18276     std::pair<unsigned, unsigned> MaskedLoad;
18277     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
18278     if (MaskedLoad.first)
18279       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
18280                                                   Value.getOperand(1), ST,this))
18281         return NewST;
18282 
18283     // Or is commutative, so try swapping X and Y.
18284     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
18285     if (MaskedLoad.first)
18286       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
18287                                                   Value.getOperand(0), ST,this))
18288         return NewST;
18289   }
18290 
18291   if (!EnableReduceLoadOpStoreWidth)
18292     return SDValue();
18293 
18294   if (Value.getOperand(1).getOpcode() != ISD::Constant)
18295     return SDValue();
18296 
18297   SDValue N0 = Value.getOperand(0);
18298   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
18299       Chain == SDValue(N0.getNode(), 1)) {
18300     LoadSDNode *LD = cast<LoadSDNode>(N0);
18301     if (LD->getBasePtr() != Ptr ||
18302         LD->getPointerInfo().getAddrSpace() !=
18303         ST->getPointerInfo().getAddrSpace())
18304       return SDValue();
18305 
18306     // Find the type to narrow the load / op / store to.
18307     SDValue N1 = Value.getOperand(1);
18308     unsigned BitWidth = N1.getValueSizeInBits();
18309     APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
18310     if (Opc == ISD::AND)
18311       Imm ^= APInt::getAllOnes(BitWidth);
18312     if (Imm == 0 || Imm.isAllOnes())
18313       return SDValue();
18314     unsigned ShAmt = Imm.countTrailingZeros();
18315     unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
18316     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
18317     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
18318     // The narrowing should be profitable, the load/store operation should be
18319     // legal (or custom) and the store size should be equal to the NewVT width.
18320     while (NewBW < BitWidth &&
18321            (NewVT.getStoreSizeInBits() != NewBW ||
18322             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
18323             !TLI.isNarrowingProfitable(VT, NewVT))) {
18324       NewBW = NextPowerOf2(NewBW);
18325       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
18326     }
18327     if (NewBW >= BitWidth)
18328       return SDValue();
18329 
18330     // If the changed lsb does not start at a multiple of the new type's
18331     // bitwidth, round it down to the previous multiple.
18332     if (ShAmt % NewBW)
18333       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
18334     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
18335                                    std::min(BitWidth, ShAmt + NewBW));
18336     if ((Imm & Mask) == Imm) {
18337       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
18338       if (Opc == ISD::AND)
18339         NewImm ^= APInt::getAllOnes(NewBW);
18340       uint64_t PtrOff = ShAmt / 8;
18341       // For big endian targets, we need to adjust the offset to the pointer to
18342       // load the correct bytes.
18343       if (DAG.getDataLayout().isBigEndian())
18344         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
18345 
18346       unsigned IsFast = 0;
18347       Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
18348       if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
18349                                   LD->getAddressSpace(), NewAlign,
18350                                   LD->getMemOperand()->getFlags(), &IsFast) ||
18351           !IsFast)
18352         return SDValue();
18353 
18354       SDValue NewPtr =
18355           DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
18356       SDValue NewLD =
18357           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
18358                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
18359                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
18360       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
18361                                    DAG.getConstant(NewImm, SDLoc(Value),
18362                                                    NewVT));
18363       SDValue NewST =
18364           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
18365                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
18366 
18367       AddToWorklist(NewPtr.getNode());
18368       AddToWorklist(NewLD.getNode());
18369       AddToWorklist(NewVal.getNode());
18370       WorklistRemover DeadNodes(*this);
18371       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
18372       ++OpsNarrowed;
18373       return NewST;
18374     }
18375   }
18376 
18377   return SDValue();
18378 }
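// For example (illustrative, little-endian):
//   store (or (load i32 p), 0x00FF0000), p
// touches only byte 2 of the value, so ShAmt == 16 and NewBW == 8; if the
// target finds i8 legal and the narrowing profitable, this may become
//   store (or (load i8 (add p, 2)), 0xFF), (add p, 2)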
18379 
18380 /// For a given floating point load / store pair, if the load value isn't used
18381 /// by any other operations, then consider transforming the pair to integer
18382 /// load / store operations if the target deems the transformation profitable.
18383 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
18384   StoreSDNode *ST  = cast<StoreSDNode>(N);
18385   SDValue Value = ST->getValue();
18386   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
18387       Value.hasOneUse()) {
18388     LoadSDNode *LD = cast<LoadSDNode>(Value);
18389     EVT VT = LD->getMemoryVT();
18390     if (!VT.isFloatingPoint() ||
18391         VT != ST->getMemoryVT() ||
18392         LD->isNonTemporal() ||
18393         ST->isNonTemporal() ||
18394         LD->getPointerInfo().getAddrSpace() != 0 ||
18395         ST->getPointerInfo().getAddrSpace() != 0)
18396       return SDValue();
18397 
18398     TypeSize VTSize = VT.getSizeInBits();
18399 
18400     // We don't know the size of scalable types at compile time so we cannot
18401     // create an integer of the equivalent size.
18402     if (VTSize.isScalable())
18403       return SDValue();
18404 
18405     unsigned FastLD = 0, FastST = 0;
18406     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
18407     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
18408         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
18409         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
18410         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
18411         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
18412                                 *LD->getMemOperand(), &FastLD) ||
18413         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
18414                                 *ST->getMemOperand(), &FastST) ||
18415         !FastLD || !FastST)
18416       return SDValue();
18417 
18418     SDValue NewLD =
18419         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
18420                     LD->getPointerInfo(), LD->getAlign());
18421 
18422     SDValue NewST =
18423         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
18424                      ST->getPointerInfo(), ST->getAlign());
18425 
18426     AddToWorklist(NewLD.getNode());
18427     AddToWorklist(NewST.getNode());
18428     WorklistRemover DeadNodes(*this);
18429     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
18430     ++LdStFP2Int;
18431     return NewST;
18432   }
18433 
18434   return SDValue();
18435 }
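// For example (illustrative): the pair
//   v = load f32 p
//   store v, q
// may become
//   v' = load i32 p
//   store v', q
// when the target reports the integer operations legal and the transform
// desirable, avoiding a round trip through the FP register bank.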
18436 
18437 // This is a helper function for visitMUL to check the profitability
18438 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
18439 // MulNode is the original multiply, AddNode is (add x, c1),
18440 // and ConstNode is c2.
18441 //
18442 // If the (add x, c1) has multiple uses, we could increase
18443 // the number of adds if we make this transformation.
18444 // It would only be worth doing this if we can remove a
18445 // multiply in the process. Check for that here.
18446 // To illustrate:
18447 //     (A + c1) * c3
18448 //     (A + c2) * c3
18449 // We're checking for cases where we have common "c3 * A" expressions.
18450 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
18451                                               SDValue ConstNode) {
18452   APInt Val;
18453 
18454   // If the add only has one use, and the target thinks the folding is
18455   // profitable or does not lead to worse code, this would be OK to do.
18456   if (AddNode->hasOneUse() &&
18457       TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
18458     return true;
18459 
18460   // Walk all the users of the constant with which we're multiplying.
18461   for (SDNode *Use : ConstNode->uses()) {
18462     if (Use == MulNode) // This use is the one we're on right now. Skip it.
18463       continue;
18464 
18465     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
18466       SDNode *OtherOp;
18467       SDNode *MulVar = AddNode.getOperand(0).getNode();
18468 
18469       // OtherOp is what we're multiplying against the constant.
18470       if (Use->getOperand(0) == ConstNode)
18471         OtherOp = Use->getOperand(1).getNode();
18472       else
18473         OtherOp = Use->getOperand(0).getNode();
18474 
18475       // Check to see if multiply is with the same operand of our "add".
18476       //
18477       //     ConstNode  = CONST
18478       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
18479       //     ...
18480       //     AddNode  = (A + c1)  <-- MulVar is A.
18481       //         = AddNode * ConstNode   <-- current visiting instruction.
18482       //
18483       // If we make this transformation, we will have a common
18484       // multiply (ConstNode * A) that we can save.
18485       if (OtherOp == MulVar)
18486         return true;
18487 
18488       // Now check to see if a future expansion will give us a common
18489       // multiply.
18490       //
18491       //     ConstNode  = CONST
18492       //     AddNode    = (A + c1)
18493       //     ...   = AddNode * ConstNode <-- current visiting instruction.
18494       //     ...
18495       //     OtherOp = (A + c2)
18496       //     Use     = OtherOp * ConstNode <-- visiting Use.
18497       //
18498       // If we make this transformation, we will have a common
18499       // multiply (CONST * A) after we also do the same transformation
18500       // to the "t2" instruction.
18501       if (OtherOp->getOpcode() == ISD::ADD &&
18502           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
18503           OtherOp->getOperand(0).getNode() == MulVar)
18504         return true;
18505     }
18506   }
18507 
18508   // Didn't find a case where this would be profitable.
18509   return false;
18510 }
18511 
18512 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
18513                                          unsigned NumStores) {
18514   SmallVector<SDValue, 8> Chains;
18515   SmallPtrSet<const SDNode *, 8> Visited;
18516   SDLoc StoreDL(StoreNodes[0].MemNode);
18517 
18518   for (unsigned i = 0; i < NumStores; ++i) {
18519     Visited.insert(StoreNodes[i].MemNode);
18520   }
18521 
18522   // Don't include nodes that are children or repeated nodes.
18523   for (unsigned i = 0; i < NumStores; ++i) {
18524     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
18525       Chains.push_back(StoreNodes[i].MemNode->getChain());
18526   }
18527 
18528   assert(Chains.size() > 0 && "Chain should have generated a chain");
18529   return DAG.getTokenFactor(StoreDL, Chains);
18530 }
18531 
18532 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
18533     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
18534     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
18535   // Make sure we have something to merge.
18536   if (NumStores < 2)
18537     return false;
18538 
18539   assert((!UseTrunc || !UseVector) &&
18540          "This optimization cannot emit a vector truncating store");
18541 
18542   // The latest Node in the DAG.
18543   SDLoc DL(StoreNodes[0].MemNode);
18544 
18545   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
18546   unsigned SizeInBits = NumStores * ElementSizeBits;
18547   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18548 
18549   std::optional<MachineMemOperand::Flags> Flags;
18550   AAMDNodes AAInfo;
18551   for (unsigned I = 0; I != NumStores; ++I) {
18552     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
18553     if (!Flags) {
18554       Flags = St->getMemOperand()->getFlags();
18555       AAInfo = St->getAAInfo();
18556       continue;
18557     }
18558     // Skip merging if there's an inconsistent flag.
18559     if (Flags != St->getMemOperand()->getFlags())
18560       return false;
18561     // Concatenate AA metadata.
18562     AAInfo = AAInfo.concat(St->getAAInfo());
18563   }
18564 
18565   EVT StoreTy;
18566   if (UseVector) {
18567     unsigned Elts = NumStores * NumMemElts;
18568     // Get the type for the merged vector store.
18569     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
18570   } else
18571     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
18572 
18573   SDValue StoredVal;
18574   if (UseVector) {
18575     if (IsConstantSrc) {
18576       SmallVector<SDValue, 8> BuildVector;
18577       for (unsigned I = 0; I != NumStores; ++I) {
18578         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
18579         SDValue Val = St->getValue();
18580         // If constant is of the wrong type, convert it now.
18581         if (MemVT != Val.getValueType()) {
18582           Val = peekThroughBitcasts(Val);
18583           // Deal with constants of wrong size.
18584           if (ElementSizeBits != Val.getValueSizeInBits()) {
18585             EVT IntMemVT =
18586                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
18587             if (isa<ConstantFPSDNode>(Val)) {
18588               // Not clear how to truncate FP values.
18589               return false;
18590             }
18591 
18592             if (auto *C = dyn_cast<ConstantSDNode>(Val))
18593               Val = DAG.getConstant(C->getAPIntValue()
18594                                         .zextOrTrunc(Val.getValueSizeInBits())
18595                                         .zextOrTrunc(ElementSizeBits),
18596                                     SDLoc(C), IntMemVT);
18597           }
18598           // Make sure the correctly sized value has the correct (MemVT) type.
18599           Val = DAG.getBitcast(MemVT, Val);
18600         }
18601         BuildVector.push_back(Val);
18602       }
18603       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
18604                                                : ISD::BUILD_VECTOR,
18605                               DL, StoreTy, BuildVector);
18606     } else {
18607       SmallVector<SDValue, 8> Ops;
18608       for (unsigned i = 0; i < NumStores; ++i) {
18609         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
18610         SDValue Val = peekThroughBitcasts(St->getValue());
18611         // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
18612         // type MemVT. If the underlying value is not the correct
18613         // type, but it is an extraction of an appropriate vector we
18614         // can recast Val to be of the correct type. This may require
18615         // converting between EXTRACT_VECTOR_ELT and
18616         // EXTRACT_SUBVECTOR.
18617         if ((MemVT != Val.getValueType()) &&
18618             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
18619              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
18620           EVT MemVTScalarTy = MemVT.getScalarType();
18621           // We may need to add a bitcast here to get types to line up.
18622           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
18623             Val = DAG.getBitcast(MemVT, Val);
18624           } else if (MemVT.isVector() &&
18625                      Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
18626             Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
18627           } else {
18628             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
18629                                             : ISD::EXTRACT_VECTOR_ELT;
18630             SDValue Vec = Val.getOperand(0);
18631             SDValue Idx = Val.getOperand(1);
18632             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
18633           }
18634         }
18635         Ops.push_back(Val);
18636       }
18637 
18638       // Build the extracted vector elements back into a vector.
18639       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
18640                                                : ISD::BUILD_VECTOR,
18641                               DL, StoreTy, Ops);
18642     }
18643   } else {
18644     // We should always use a vector store when merging extracted vector
18645     // elements, so this path implies a store of constants.
18646     assert(IsConstantSrc && "Merged vector elements should use vector store");
18647 
18648     APInt StoreInt(SizeInBits, 0);
18649 
18650     // Construct a single integer constant which is made of the smaller
18651     // constant inputs.
18652     bool IsLE = DAG.getDataLayout().isLittleEndian();
18653     for (unsigned i = 0; i < NumStores; ++i) {
18654       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
18655       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
18656 
18657       SDValue Val = St->getValue();
18658       Val = peekThroughBitcasts(Val);
18659       StoreInt <<= ElementSizeBits;
18660       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
18661         StoreInt |= C->getAPIntValue()
18662                         .zextOrTrunc(ElementSizeBits)
18663                         .zextOrTrunc(SizeInBits);
18664       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
18665         StoreInt |= C->getValueAPF()
18666                         .bitcastToAPInt()
18667                         .zextOrTrunc(ElementSizeBits)
18668                         .zextOrTrunc(SizeInBits);
18669         // If fp truncation is necessary give up for now.
18670         if (MemVT.getSizeInBits() != ElementSizeBits)
18671           return false;
18672       } else {
18673         llvm_unreachable("Invalid constant element type");
18674       }
18675     }
18676 
18677     // Create the new Load and Store operations.
18678     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
18679   }
18680 
18681   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18682   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
18683 
18684   // Make sure we use a truncating store if that is necessary for legality.
18685   SDValue NewStore;
18686   if (!UseTrunc) {
18687     NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
18688                             FirstInChain->getPointerInfo(),
18689                             FirstInChain->getAlign(), *Flags, AAInfo);
18690   } else { // Must be realized as a trunc store
18691     EVT LegalizedStoredValTy =
18692         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
18693     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
18694     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
18695     SDValue ExtendedStoreVal =
18696         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
18697                         LegalizedStoredValTy);
18698     NewStore = DAG.getTruncStore(
18699         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
18700         FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
18701         FirstInChain->getAlign(), *Flags, AAInfo);
18702   }
18703 
18704   // Replace all merged stores with the new store.
18705   for (unsigned i = 0; i < NumStores; ++i)
18706     CombineTo(StoreNodes[i].MemNode, NewStore);
18707 
18708   AddToWorklist(NewChain.getNode());
18709   return true;
18710 }
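// For example (illustrative, little-endian): four consecutive i8 stores of
// 0x01, 0x02, 0x03 and 0x04 at p, p+1, p+2 and p+3 may be merged, when i32
// is legal, into a single
//   store (i32 0x04030201), p
// which writes the same bytes at the same addresses.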
18711 
18712 void DAGCombiner::getStoreMergeCandidates(
18713     StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
18714     SDNode *&RootNode) {
18715   // This holds the base pointer, index, and the offset in bytes from the base
18716   // pointer. We must have a base and an offset. Do not handle stores to undef
18717   // base pointers.
18718   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
18719   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
18720     return;
18721 
18722   SDValue Val = peekThroughBitcasts(St->getValue());
18723   StoreSource StoreSrc = getStoreSource(Val);
18724   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
18725 
18726   // Match on loadbaseptr if relevant.
18727   EVT MemVT = St->getMemoryVT();
18728   BaseIndexOffset LBasePtr;
18729   EVT LoadVT;
18730   if (StoreSrc == StoreSource::Load) {
18731     auto *Ld = cast<LoadSDNode>(Val);
18732     LBasePtr = BaseIndexOffset::match(Ld, DAG);
18733     LoadVT = Ld->getMemoryVT();
18734     // Load and store should be the same type.
18735     if (MemVT != LoadVT)
18736       return;
18737     // Loads must only have one use.
18738     if (!Ld->hasNUsesOfValue(1, 0))
18739       return;
18740     // The memory operands must not be volatile/indexed/atomic.
18741     // TODO: May be able to relax for unordered atomics (see D66309)
18742     if (!Ld->isSimple() || Ld->isIndexed())
18743       return;
18744   }
18745   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
18746                             int64_t &Offset) -> bool {
18747     // The memory operands must not be volatile/indexed/atomic.
18748     // TODO: May be able to relax for unordered atomics (see D66309)
18749     if (!Other->isSimple() || Other->isIndexed())
18750       return false;
18751     // Don't mix temporal stores with non-temporal stores.
18752     if (St->isNonTemporal() != Other->isNonTemporal())
18753       return false;
18754     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
18755     // Allow merging constants of different types as integers.
18756     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
18757                                            : Other->getMemoryVT() != MemVT;
18758     switch (StoreSrc) {
18759     case StoreSource::Load: {
18760       if (NoTypeMatch)
18761         return false;
18762       // The Load's Base Ptr must also match.
18763       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
18764       if (!OtherLd)
18765         return false;
18766       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
18767       if (LoadVT != OtherLd->getMemoryVT())
18768         return false;
18769       // Loads must only have one use.
18770       if (!OtherLd->hasNUsesOfValue(1, 0))
18771         return false;
18772       // The memory operands must not be volatile/indexed/atomic.
18773       // TODO: May be able to relax for unordered atomics (see D66309)
18774       if (!OtherLd->isSimple() || OtherLd->isIndexed())
18775         return false;
18776       // Don't mix temporal loads with non-temporal loads.
18777       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
18778         return false;
18779       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
18780         return false;
18781       break;
18782     }
18783     case StoreSource::Constant:
18784       if (NoTypeMatch)
18785         return false;
18786       if (!isIntOrFPConstant(OtherBC))
18787         return false;
18788       break;
18789     case StoreSource::Extract:
18790       // Do not merge truncated stores here.
18791       if (Other->isTruncatingStore())
18792         return false;
18793       if (!MemVT.bitsEq(OtherBC.getValueType()))
18794         return false;
18795       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
18796           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18797         return false;
18798       break;
18799     default:
18800       llvm_unreachable("Unhandled store source for merging");
18801     }
18802     Ptr = BaseIndexOffset::match(Other, DAG);
18803     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
18804   };
18805 
18806   // Check if the pair of StoreNode and RootNode has already bailed out of
18807   // the dependence check more times than the limit allows.
18808   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
18809                                         SDNode *RootNode) -> bool {
18810     auto RootCount = StoreRootCountMap.find(StoreNode);
18811     return RootCount != StoreRootCountMap.end() &&
18812            RootCount->second.first == RootNode &&
18813            RootCount->second.second > StoreMergeDependenceLimit;
18814   };
18815 
18816   auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
18817     // This must be a chain use.
18818     if (UseIter.getOperandNo() != 0)
18819       return;
18820     if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
18821       BaseIndexOffset Ptr;
18822       int64_t PtrDiff;
18823       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
18824           !OverLimitInDependenceCheck(OtherStore, RootNode))
18825         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
18826     }
18827   };
18828 
18829   // We are looking for a root node which is an ancestor to all mergeable
18830   // stores. We search up through a load, to our root and then down
18831   // through all children. For instance we will find Store{1,2,3} if
18832   // St is Store1, Store2 or Store3 where the root is not a load,
18833   // which is always true for nonvolatile ops. TODO: Expand
18834   // the search to find all valid candidates through multiple layers of loads.
18835   //
18836   // Root
18837   // |-------|-------|
18838   // Load    Load    Store3
18839   // |       |
18840   // Store1   Store2
18841   //
18842   // FIXME: We should be able to climb and
18843   // descend TokenFactors to find candidates as well.
18844 
18845   RootNode = St->getChain().getNode();
18846 
18847   unsigned NumNodesExplored = 0;
18848   const unsigned MaxSearchNodes = 1024;
18849   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
18850     RootNode = Ldn->getChain().getNode();
18851     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
18852          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
18853       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
18854         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
18855           TryToAddCandidate(I2);
18856       }
18857       // Check stores that depend on the root (e.g. Store 3 in the chart above).
18858       if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
18859         TryToAddCandidate(I);
18860       }
18861     }
18862   } else {
18863     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
18864          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
18865       TryToAddCandidate(I);
18866   }
18867 }
18868 
18869 // We need to check that merging these stores does not cause a loop in the
18870 // DAG. Any store candidate may depend on another candidate indirectly through
18871 // its operands. Check in parallel by searching up from operands of candidates.
18872 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
18873     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
18874     SDNode *RootNode) {
18875   // FIXME: We should be able to truncate a full search of
18876   // predecessors by doing a BFS and keeping tabs on the originating
18877   // stores from which worklist nodes come, in a similar way to
18878   // TokenFactor simplification.
18879 
18880   SmallPtrSet<const SDNode *, 32> Visited;
18881   SmallVector<const SDNode *, 8> Worklist;
18882 
18883   // RootNode is a predecessor to all candidates so we need not search
18884   // past it. Add RootNode (peeking through TokenFactors). Do not count
18885   // these towards the size check.
18886 
18887   Worklist.push_back(RootNode);
18888   while (!Worklist.empty()) {
18889     auto N = Worklist.pop_back_val();
18890     if (!Visited.insert(N).second)
18891       continue; // Already present in Visited.
18892     if (N->getOpcode() == ISD::TokenFactor) {
18893       for (SDValue Op : N->ops())
18894         Worklist.push_back(Op.getNode());
18895     }
18896   }
18897 
18898   // Don't count pruning nodes towards max.
18899   unsigned int Max = 1024 + Visited.size();
18900   // Search Ops of store candidates.
18901   for (unsigned i = 0; i < NumStores; ++i) {
18902     SDNode *N = StoreNodes[i].MemNode;
18903     // Of the 4 Store Operands:
18904     //   * Chain (Op 0) -> We have already considered these
18905     //                     in candidate selection, but only by following the
18906     //                     chain dependencies. We could still have a chain
18907     //                     dependency to a load, that has a non-chain dep to
18908     //                     another load, that depends on a store, etc. So it is
18909     //                     possible to have dependencies that consist of a mix
18910     //                     of chain and non-chain deps, and we need to include
18911     //                     chain operands in the analysis here.
18912     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
18913     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
18914     //                       but aren't necessarily from the same base node, so
18915     //                       cycles possible (e.g. via indexed store).
18916     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
18917     //               non-indexed stores). Not constant on all targets (e.g. ARM)
18918     //               and so can participate in a cycle.
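    // For example (a sketch of one possible cycle): if candidate St1's value
    // operand is a load whose chain reaches candidate St2, then a merged node
    // combining St1 and St2 would be a predecessor of itself.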
18919     for (unsigned j = 0; j < N->getNumOperands(); ++j)
18920       Worklist.push_back(N->getOperand(j).getNode());
18921   }
18922   // Search through DAG. We can stop early if we find a store node.
18923   for (unsigned i = 0; i < NumStores; ++i)
18924     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
18925                                      Max)) {
18926       // If the search bails out, record the StoreNode and RootNode in the
18927       // StoreRootCountMap. If we have seen the pair more times than a limit,
18928       // we won't add the StoreNode into the StoreNodes set again.
18929       if (Visited.size() >= Max) {
18930         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
18931         if (RootCount.first == RootNode)
18932           RootCount.second++;
18933         else
18934           RootCount = {RootNode, 1};
18935       }
18936       return false;
18937     }
18938   return true;
18939 }
18940 
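// For example (an illustrative run): with ElementSizeBytes == 4 and sorted
// offsets {0, 4, 8, 20}, the stores at offsets 0, 4, and 8 are consecutive,
// so this returns 3 and leaves the store at offset 20 in StoreNodes for a
// later iteration.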
18941 unsigned
18942 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
18943                                   int64_t ElementSizeBytes) const {
18944   while (true) {
18945     // Find a store past the width of the first store.
18946     size_t StartIdx = 0;
18947     while ((StartIdx + 1 < StoreNodes.size()) &&
18948            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
18949               StoreNodes[StartIdx + 1].OffsetFromBase)
18950       ++StartIdx;
18951 
18952     // Bail if we don't have enough candidates to merge.
18953     if (StartIdx + 1 >= StoreNodes.size())
18954       return 0;
18955 
18956     // Trim stores that overlapped with the first store.
18957     if (StartIdx)
18958       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
18959 
18960     // Scan the memory operations on the chain and find the first
18961     // non-consecutive store memory address.
18962     unsigned NumConsecutiveStores = 1;
18963     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
18964     // Check that the addresses are consecutive starting from the second
18965     // element in the list of stores.
18966     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
18967       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
18968       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
18969         break;
18970       NumConsecutiveStores = i + 1;
18971     }
18972     if (NumConsecutiveStores > 1)
18973       return NumConsecutiveStores;
18974 
18975     // There are no consecutive stores at the start of the list.
18976     // Remove the first store and try again.
18977     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
18978   }
18979 }
18980 
18981 bool DAGCombiner::tryStoreMergeOfConstants(
18982     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
18983     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
18984   LLVMContext &Context = *DAG.getContext();
18985   const DataLayout &DL = DAG.getDataLayout();
18986   int64_t ElementSizeBytes = MemVT.getStoreSize();
18987   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18988   bool MadeChange = false;
18989 
18990   // Store the constants into memory as one consecutive store.
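  // For example, on a little-endian target where i32 stores are legal
  // (an illustrative sketch):
  //   (store i16 1, p) ; (store i16 2, p+2)  -->  (store i32 0x00020001, p)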
18991   while (NumConsecutiveStores >= 2) {
18992     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18993     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18994     Align FirstStoreAlign = FirstInChain->getAlign();
18995     unsigned LastLegalType = 1;
18996     unsigned LastLegalVectorType = 1;
18997     bool LastIntegerTrunc = false;
18998     bool NonZero = false;
18999     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
19000     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19001       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
19002       SDValue StoredVal = ST->getValue();
19003       bool IsElementZero = false;
19004       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
19005         IsElementZero = C->isZero();
19006       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
19007         IsElementZero = C->getConstantFPValue()->isNullValue();
19008       if (IsElementZero) {
19009         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
19010           FirstZeroAfterNonZero = i;
19011       }
19012       NonZero |= !IsElementZero;
19013 
19014       // Find a legal type for the constant store.
19015       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
19016       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
19017       unsigned IsFast = 0;
19018 
19019       // Break early when size is too large to be legal.
19020       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
19021         break;
19022 
19023       if (TLI.isTypeLegal(StoreTy) &&
19024           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19025                                DAG.getMachineFunction()) &&
19026           TLI.allowsMemoryAccess(Context, DL, StoreTy,
19027                                  *FirstInChain->getMemOperand(), &IsFast) &&
19028           IsFast) {
19029         LastIntegerTrunc = false;
19030         LastLegalType = i + 1;
19031         // Or check whether a truncstore is legal.
19032       } else if (TLI.getTypeAction(Context, StoreTy) ==
19033                  TargetLowering::TypePromoteInteger) {
19034         EVT LegalizedStoredValTy =
19035             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
19036         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
19037             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
19038                                  DAG.getMachineFunction()) &&
19039             TLI.allowsMemoryAccess(Context, DL, StoreTy,
19040                                    *FirstInChain->getMemOperand(), &IsFast) &&
19041             IsFast) {
19042           LastIntegerTrunc = true;
19043           LastLegalType = i + 1;
19044         }
19045       }
19046 
19047       // We only use vectors if the constant is known to be zero or the
19048       // target allows it and the function is not marked with the
19049       // noimplicitfloat attribute.
19050       if ((!NonZero ||
19051            TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
19052           AllowVectors) {
19053         // Find a legal type for the vector store.
19054         unsigned Elts = (i + 1) * NumMemElts;
19055         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19056         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
19057             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
19058             TLI.allowsMemoryAccess(Context, DL, Ty,
19059                                    *FirstInChain->getMemOperand(), &IsFast) &&
19060             IsFast)
19061           LastLegalVectorType = i + 1;
19062       }
19063     }
19064 
19065     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
19066     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
19067     bool UseTrunc = LastIntegerTrunc && !UseVector;
19068 
19069     // Check if we found a legal integer type that creates a meaningful
19070     // merge.
19071     if (NumElem < 2) {
19072       // We know that candidate stores are in order and of correct
19073       // shape. While there is no mergeable sequence from the
19074       // beginning, one may start later in the sequence. The only
19075       // reason a merge of size N could have failed where another of
19076       // the same size would not have, is if the alignment has
19077       // improved or we've dropped a non-zero value. Drop as many
19078       // candidates as we can here.
19079       unsigned NumSkip = 1;
19080       while ((NumSkip < NumConsecutiveStores) &&
19081              (NumSkip < FirstZeroAfterNonZero) &&
19082              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19083         NumSkip++;
19084 
19085       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19086       NumConsecutiveStores -= NumSkip;
19087       continue;
19088     }
19089 
19090     // Check that we can merge these candidates without causing a cycle.
19091     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
19092                                                   RootNode)) {
19093       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19094       NumConsecutiveStores -= NumElem;
19095       continue;
19096     }
19097 
19098     MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
19099                                                   /*IsConstantSrc*/ true,
19100                                                   UseVector, UseTrunc);
19101 
19102     // Remove merged stores for next iteration.
19103     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19104     NumConsecutiveStores -= NumElem;
19105   }
19106   return MadeChange;
19107 }
19108 
19109 bool DAGCombiner::tryStoreMergeOfExtracts(
19110     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
19111     EVT MemVT, SDNode *RootNode) {
19112   LLVMContext &Context = *DAG.getContext();
19113   const DataLayout &DL = DAG.getDataLayout();
19114   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
19115   bool MadeChange = false;
19116 
19117   // Keep looping over the consecutive stores while merging succeeds.
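  // For example (illustrative, assuming the v4i32 store is legal and fast):
  //   (store (extract_vector_elt V, 0), p) ; ... ;
  //   (store (extract_vector_elt V, 3), p+12)
  //     --> (store V, p), with V : v4i32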
19118   while (NumConsecutiveStores >= 2) {
19119     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
19120     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
19121     Align FirstStoreAlign = FirstInChain->getAlign();
19122     unsigned NumStoresToMerge = 1;
19123     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19124       // Find a legal type for the vector store.
19125       unsigned Elts = (i + 1) * NumMemElts;
19126       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
19127       unsigned IsFast = 0;
19128 
19129       // Break early when size is too large to be legal.
19130       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
19131         break;
19132 
19133       if (TLI.isTypeLegal(Ty) &&
19134           TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
19135           TLI.allowsMemoryAccess(Context, DL, Ty,
19136                                  *FirstInChain->getMemOperand(), &IsFast) &&
19137           IsFast)
19138         NumStoresToMerge = i + 1;
19139     }
19140 
19141     // Check if we found a legal vector type that creates a meaningful
19142     // merge.
19143     if (NumStoresToMerge < 2) {
19144       // We know that candidate stores are in order and of correct
19145       // shape. While there is no mergeable sequence from the
19146       // beginning, one may start later in the sequence. The only
19147       // reason a merge of size N could have failed where another of
19148       // the same size would not have, is if the alignment has
19149       // improved. Drop as many candidates as we can here.
19150       unsigned NumSkip = 1;
19151       while ((NumSkip < NumConsecutiveStores) &&
19152              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19153         NumSkip++;
19154 
19155       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19156       NumConsecutiveStores -= NumSkip;
19157       continue;
19158     }
19159 
19160     // Check that we can merge these candidates without causing a cycle.
19161     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
19162                                                   RootNode)) {
19163       StoreNodes.erase(StoreNodes.begin(),
19164                        StoreNodes.begin() + NumStoresToMerge);
19165       NumConsecutiveStores -= NumStoresToMerge;
19166       continue;
19167     }
19168 
19169     MadeChange |= mergeStoresOfConstantsOrVecElts(
19170         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
19171         /*UseVector*/ true, /*UseTrunc*/ false);
19172 
19173     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
19174     NumConsecutiveStores -= NumStoresToMerge;
19175   }
19176   return MadeChange;
19177 }
19178 
19179 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
19180                                        unsigned NumConsecutiveStores, EVT MemVT,
19181                                        SDNode *RootNode, bool AllowVectors,
19182                                        bool IsNonTemporalStore,
19183                                        bool IsNonTemporalLoad) {
19184   LLVMContext &Context = *DAG.getContext();
19185   const DataLayout &DL = DAG.getDataLayout();
19186   int64_t ElementSizeBytes = MemVT.getStoreSize();
19187   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
19188   bool MadeChange = false;
19189 
19190   // Look for load nodes which are used by the stored values.
19191   SmallVector<MemOpLink, 8> LoadNodes;
19192 
19193   // Find acceptable loads. Loads need to have the same chain (token factor),
19194   // must not be zext, volatile, or indexed, and they must be consecutive.
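  // For example (illustrative, assuming i64 loads/stores are legal and fast):
  //   (store (load i32 p), q) ; (store (load i32 p+4), q+4)
  //     --> (store (load i64 p), q)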
19195   BaseIndexOffset LdBasePtr;
19196 
19197   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19198     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
19199     SDValue Val = peekThroughBitcasts(St->getValue());
19200     LoadSDNode *Ld = cast<LoadSDNode>(Val);
19201 
19202     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
19203     // If this is not the first pointer that we check:
19204     int64_t LdOffset = 0;
19205     if (LdBasePtr.getBase().getNode()) {
19206       // The base ptr must be the same.
19207       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
19208         break;
19209     } else {
19210       // Remember this base pointer; all later base pointers must match it.
19211       LdBasePtr = LdPtr;
19212     }
19213 
19214     // We found a potential memory operand to merge.
19215     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
19216   }
19217 
19218   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
19219     Align RequiredAlignment;
19220     bool NeedRotate = false;
19221     if (LoadNodes.size() == 2) {
19222       // If we have load/store pair instructions and we only have two values,
19223       // don't bother merging.
19224       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
19225           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
19226         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
19227         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
19228         break;
19229       }
19230       // If the loads are reversed, see if we can rotate the halves into place.
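      // For example (illustrative, assuming a legal i64 rotate):
      //   (store (load i32 p+4), q) ; (store (load i32 p), q+4)
      //     --> (store (rotl (load i64 p), 32), q)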
19231       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
19232       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
19233       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
19234       if (Offset0 - Offset1 == ElementSizeBytes &&
19235           (hasOperation(ISD::ROTL, PairVT) ||
19236            hasOperation(ISD::ROTR, PairVT))) {
19237         std::swap(LoadNodes[0], LoadNodes[1]);
19238         NeedRotate = true;
19239       }
19240     }
19241     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
19242     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
19243     Align FirstStoreAlign = FirstInChain->getAlign();
19244     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
19245 
19246     // Scan the memory operations on the chain and find the first
19247     // non-consecutive load memory address. These variables hold the index in
19248     // the store node array.
19249 
19250     unsigned LastConsecutiveLoad = 1;
19251 
19252     // These variables refer to a size, not an index in the array.
19253     unsigned LastLegalVectorType = 1;
19254     unsigned LastLegalIntegerType = 1;
19255     bool isDereferenceable = true;
19256     bool DoIntegerTruncate = false;
19257     int64_t StartAddress = LoadNodes[0].OffsetFromBase;
19258     SDValue LoadChain = FirstLoad->getChain();
19259     for (unsigned i = 1; i < LoadNodes.size(); ++i) {
19260       // All loads must share the same chain.
19261       if (LoadNodes[i].MemNode->getChain() != LoadChain)
19262         break;
19263 
19264       int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
19265       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
19266         break;
19267       LastConsecutiveLoad = i;
19268 
19269       if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
19270         isDereferenceable = false;
19271 
19272       // Find a legal type for the vector store.
19273       unsigned Elts = (i + 1) * NumMemElts;
19274       EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19275 
19276       // Break early when size is too large to be legal.
19277       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
19278         break;
19279 
19280       unsigned IsFastSt = 0;
19281       unsigned IsFastLd = 0;
19282       // Don't try vector types if we need a rotate. We may still fail the
19283       // legality checks for the integer type, but we can't handle the rotate
19284       // case with vectors.
19285       // FIXME: We could use a shuffle in place of the rotate.
19286       if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
19287           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19288                                DAG.getMachineFunction()) &&
19289           TLI.allowsMemoryAccess(Context, DL, StoreTy,
19290                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
19291           IsFastSt &&
19292           TLI.allowsMemoryAccess(Context, DL, StoreTy,
19293                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
19294           IsFastLd) {
19295         LastLegalVectorType = i + 1;
19296       }
19297 
19298       // Find a legal type for the integer store.
19299       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
19300       StoreTy = EVT::getIntegerVT(Context, SizeInBits);
19301       if (TLI.isTypeLegal(StoreTy) &&
19302           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19303                                DAG.getMachineFunction()) &&
19304           TLI.allowsMemoryAccess(Context, DL, StoreTy,
19305                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
19306           IsFastSt &&
19307           TLI.allowsMemoryAccess(Context, DL, StoreTy,
19308                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
19309           IsFastLd) {
19310         LastLegalIntegerType = i + 1;
19311         DoIntegerTruncate = false;
19312         // Or check whether a truncstore and extload is legal.
19313       } else if (TLI.getTypeAction(Context, StoreTy) ==
19314                  TargetLowering::TypePromoteInteger) {
19315         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
19316         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
19317             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
19318                                  DAG.getMachineFunction()) &&
19319             TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
19320             TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
19321             TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
19322             TLI.allowsMemoryAccess(Context, DL, StoreTy,
19323                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
19324             IsFastSt &&
19325             TLI.allowsMemoryAccess(Context, DL, StoreTy,
19326                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
19327             IsFastLd) {
19328           LastLegalIntegerType = i + 1;
19329           DoIntegerTruncate = true;
19330         }
19331       }
19332     }
19333 
19334     // Only use vector types if the vector type is larger than the integer
19335     // type. If they are the same, use integers.
19336     bool UseVectorTy =
19337         LastLegalVectorType > LastLegalIntegerType && AllowVectors;
19338     unsigned LastLegalType =
19339         std::max(LastLegalVectorType, LastLegalIntegerType);
19340 
19341     // We add +1 here because the LastXXX variables refer to a location
19342     // (index) while NumElem refers to an element count (size).
19343     unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
19344     NumElem = std::min(LastLegalType, NumElem);
19345     Align FirstLoadAlign = FirstLoad->getAlign();
19346 
19347     if (NumElem < 2) {
19348       // We know that candidate stores are in order and of correct
19349       // shape. While there is no mergeable sequence from the
19350       // beginning, one may start later in the sequence. The only
19351       // reason a merge of size N could have failed where another of
19352       // the same size would not have is if the alignment of either
19353       // the load or store has improved. Drop as many candidates as we
19354       // can here.
19355       unsigned NumSkip = 1;
19356       while ((NumSkip < LoadNodes.size()) &&
19357              (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
19358              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19359         NumSkip++;
19360       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19361       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
19362       NumConsecutiveStores -= NumSkip;
19363       continue;
19364     }
19365 
19366     // Check that we can merge these candidates without causing a cycle.
19367     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
19368                                                   RootNode)) {
19369       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19370       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
19371       NumConsecutiveStores -= NumElem;
19372       continue;
19373     }
19374 
19375     // Find if it is better to use vectors or integers to load and store
19376     // to memory.
19377     EVT JointMemOpVT;
19378     if (UseVectorTy) {
19379       // Find a legal type for the vector store.
19380       unsigned Elts = NumElem * NumMemElts;
19381       JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19382     } else {
19383       unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
19384       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
19385     }
19386 
19387     SDLoc LoadDL(LoadNodes[0].MemNode);
19388     SDLoc StoreDL(StoreNodes[0].MemNode);
19389 
19390     // The merged loads are required to have the same incoming chain, so
19391     // using the first's chain is acceptable.
19392 
19393     SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
19394     AddToWorklist(NewStoreChain.getNode());
19395 
19396     MachineMemOperand::Flags LdMMOFlags =
19397         isDereferenceable ? MachineMemOperand::MODereferenceable
19398                           : MachineMemOperand::MONone;
19399     if (IsNonTemporalLoad)
19400       LdMMOFlags |= MachineMemOperand::MONonTemporal;
19401 
19402     MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
19403                                               ? MachineMemOperand::MONonTemporal
19404                                               : MachineMemOperand::MONone;
19405 
19406     SDValue NewLoad, NewStore;
19407     if (UseVectorTy || !DoIntegerTruncate) {
19408       NewLoad = DAG.getLoad(
19409           JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
19410           FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
19411       SDValue StoreOp = NewLoad;
19412       if (NeedRotate) {
19413         unsigned LoadWidth = ElementSizeBytes * 8 * 2;
19414         assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
19415                "Unexpected type for rotate-able load pair");
19416         SDValue RotAmt =
19417             DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
19418         // The target can convert this to the equivalent ROTR if it lacks ROTL.
19419         StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
19420       }
19421       NewStore = DAG.getStore(
19422           NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
19423           FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
19424     } else { // This must be the truncstore/extload case
19425       EVT ExtendedTy =
19426           TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
19427       NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
19428                                FirstLoad->getChain(), FirstLoad->getBasePtr(),
19429                                FirstLoad->getPointerInfo(), JointMemOpVT,
19430                                FirstLoadAlign, LdMMOFlags);
19431       NewStore = DAG.getTruncStore(
19432           NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
19433           FirstInChain->getPointerInfo(), JointMemOpVT,
19434           FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
19435     }
19436 
19437     // Transfer chain users from old loads to the new load.
19438     for (unsigned i = 0; i < NumElem; ++i) {
19439       LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
19440       DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
19441                                     SDValue(NewLoad.getNode(), 1));
19442     }
19443 
19444     // Replace all stores with the new store. Recursively remove corresponding
19445     // values if they are no longer used.
19446     for (unsigned i = 0; i < NumElem; ++i) {
19447       SDValue Val = StoreNodes[i].MemNode->getOperand(1);
19448       CombineTo(StoreNodes[i].MemNode, NewStore);
19449       if (Val->use_empty())
19450         recursivelyDeleteUnusedNodes(Val.getNode());
19451     }
19452 
19453     MadeChange = true;
19454     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19455     LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
19456     NumConsecutiveStores -= NumElem;
19457   }
19458   return MadeChange;
19459 }
19460 
19461 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
19462   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
19463     return false;
19464 
19465   // TODO: Extend this function to merge stores of scalable vectors.
19466   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
19467   // store since we know <vscale x 16 x i8> is exactly twice as large as
19468   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
19469   EVT MemVT = St->getMemoryVT();
19470   if (MemVT.isScalableVector())
19471     return false;
19472   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
19473     return false;
19474 
19475   // This function cannot currently deal with non-byte-sized memory sizes.
19476   int64_t ElementSizeBytes = MemVT.getStoreSize();
19477   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
19478     return false;
19479 
19480   // Do not bother looking at stored values that are not constants, loads, or
19481   // extracted vector elements.
19482   SDValue StoredVal = peekThroughBitcasts(St->getValue());
19483   const StoreSource StoreSrc = getStoreSource(StoredVal);
19484   if (StoreSrc == StoreSource::Unknown)
19485     return false;
19486 
19487   SmallVector<MemOpLink, 8> StoreNodes;
19488   SDNode *RootNode;
19489   // Find potential store merge candidates by searching through the chain sub-DAG.
19490   getStoreMergeCandidates(St, StoreNodes, RootNode);
19491 
19492   // Check if there is anything to merge.
19493   if (StoreNodes.size() < 2)
19494     return false;
19495 
19496   // Sort the memory operands according to their distance from the
19497   // base pointer.
19498   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
19499     return LHS.OffsetFromBase < RHS.OffsetFromBase;
19500   });
19501 
19502   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
19503       Attribute::NoImplicitFloat);
19504   bool IsNonTemporalStore = St->isNonTemporal();
19505   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
19506                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
19507 
19508   // Store merging attempts to merge the lowest-addressed stores first. This
19509   // generally works out well when it succeeds, as the remaining stores are
19510   // checked after the first collection of stores is merged. However, in the
19511   // case that a non-mergeable store is found first, e.g., {p[-2],
19512   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
19513   // mergeable cases. To prevent this, we prune such stores from the
19514   // front of StoreNodes here.
19515   bool MadeChange = false;
19516   while (StoreNodes.size() > 1) {
19517     unsigned NumConsecutiveStores =
19518         getConsecutiveStores(StoreNodes, ElementSizeBytes);
19519     // There are no more stores in the list to examine.
19520     if (NumConsecutiveStores == 0)
19521       return MadeChange;
19522 
19523     // We have at least 2 consecutive stores. Try to merge them.
19524     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
19525     switch (StoreSrc) {
19526     case StoreSource::Constant:
19527       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
19528                                              MemVT, RootNode, AllowVectors);
19529       break;
19530 
19531     case StoreSource::Extract:
19532       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
19533                                             MemVT, RootNode);
19534       break;
19535 
19536     case StoreSource::Load:
19537       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
19538                                          MemVT, RootNode, AllowVectors,
19539                                          IsNonTemporalStore, IsNonTemporalLoad);
19540       break;
19541 
19542     default:
19543       llvm_unreachable("Unhandled store source type");
19544     }
19545   }
19546   return MadeChange;
19547 }
19548 
19549 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
19550   SDLoc SL(ST);
19551   SDValue ReplStore;
19552 
19553   // Replace the chain to avoid dependency.
19554   if (ST->isTruncatingStore()) {
19555     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
19556                                   ST->getBasePtr(), ST->getMemoryVT(),
19557                                   ST->getMemOperand());
19558   } else {
19559     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
19560                              ST->getMemOperand());
19561   }
19562 
19563   // Create token to keep both nodes around.
19564   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
19565                               MVT::Other, ST->getChain(), ReplStore);
19566 
19567   // Make sure the new and old chains are cleaned up.
19568   AddToWorklist(Token.getNode());
19569 
19570   // Don't add users to work list.
19571   return CombineTo(ST, Token, false);
19572 }
19573 
19574 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
19575   SDValue Value = ST->getValue();
19576   if (Value.getOpcode() == ISD::TargetConstantFP)
19577     return SDValue();
19578 
19579   if (!ISD::isNormalStore(ST))
19580     return SDValue();
19581 
19582   SDLoc DL(ST);
19583 
19584   SDValue Chain = ST->getChain();
19585   SDValue Ptr = ST->getBasePtr();
19586 
19587   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
19588 
19589   // NOTE: If the original store is volatile, this transform must not increase
19590   // the number of stores.  For example, on x86-32 an f64 can be stored in one
19591   // processor operation but an i64 (which is not legal) requires two.  So the
19592   // transform should not be done in this case.
19593 
19594   SDValue Tmp;
19595   switch (CFP->getSimpleValueType(0).SimpleTy) {
19596   default:
19597     llvm_unreachable("Unknown FP type");
19598   case MVT::f16:    // We don't do this for these yet.
19599   case MVT::bf16:
19600   case MVT::f80:
19601   case MVT::f128:
19602   case MVT::ppcf128:
19603     return SDValue();
19604   case MVT::f32:
19605     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
19606         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
19607       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
19608                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
19609                             MVT::i32);
19610       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
19611     }
19612 
19613     return SDValue();
19614   case MVT::f64:
19615     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
19616          ST->isSimple()) ||
19617         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
19618       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
19619                             getZExtValue(), SDLoc(CFP), MVT::i64);
19620       return DAG.getStore(Chain, DL, Tmp,
19621                           Ptr, ST->getMemOperand());
19622     }
19623 
19624     if (ST->isSimple() &&
19625         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
19626       // Many FP stores are not made apparent until after legalize, e.g. for
19627       // argument passing.  Since this is so common, custom legalize the
19628       // 64-bit integer store into two 32-bit stores.
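      // For example, storing f64 1.0 (bit pattern 0x3FF0000000000000) becomes
      //   (store i32 0x00000000, Ptr) ; (store i32 0x3FF00000, Ptr+4)
      // on a little-endian target; the halves are swapped when big-endian.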
19629       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
19630       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
19631       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
19632       if (DAG.getDataLayout().isBigEndian())
19633         std::swap(Lo, Hi);
19634 
19635       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
19636       AAMDNodes AAInfo = ST->getAAInfo();
19637 
19638       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
19639                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
19640       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
19641       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
19642                                  ST->getPointerInfo().getWithOffset(4),
19643                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
19644       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
19645                          St0, St1);
19646     }
19647 
19648     return SDValue();
19649   }
19650 }
19651 
19652 SDValue DAGCombiner::visitSTORE(SDNode *N) {
19653   StoreSDNode *ST  = cast<StoreSDNode>(N);
19654   SDValue Chain = ST->getChain();
19655   SDValue Value = ST->getValue();
19656   SDValue Ptr   = ST->getBasePtr();
19657 
19658   // If this is a store of a bit convert, store the input value if the
19659   // resultant store does not need a higher alignment than the original.
19660   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
19661       ST->isUnindexed()) {
19662     EVT SVT = Value.getOperand(0).getValueType();
19663     // If the store is volatile, we only want to change the store type if the
19664     // resulting store is legal. Otherwise we might increase the number of
19665     // memory accesses. We don't care if the original type was legal or not
19666     // as we assume software couldn't rely on the number of accesses of an
19667     // illegal type.
19668     // TODO: May be able to relax for unordered atomics (see D66309)
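    // For example (illustrative):
    //   (store (bitcast v2f32 X to i64), p) --> (store v2f32 X, p)
    // when the v2f32 store is legal (or we have not legalized operations yet
    // and the store is simple) and the target considers it beneficial.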
19669     if (((!LegalOperations && ST->isSimple()) ||
19670          TLI.isOperationLegal(ISD::STORE, SVT)) &&
19671         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
19672                                      DAG, *ST->getMemOperand())) {
19673       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19674                           ST->getMemOperand());
19675     }
19676   }
19677 
19678   // Turn 'store undef, Ptr' -> nothing.
19679   if (Value.isUndef() && ST->isUnindexed())
19680     return Chain;
19681 
19682   // Try to infer better alignment information than the store already has.
19683   if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
19684     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19685       if (*Alignment > ST->getAlign() &&
19686           isAligned(*Alignment, ST->getSrcValueOffset())) {
19687         SDValue NewStore =
19688             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
19689                               ST->getMemoryVT(), *Alignment,
19690                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
19691         // NewStore will always be N as we are only refining the alignment
19692         assert(NewStore.getNode() == N);
19693         (void)NewStore;
19694       }
19695     }
19696   }
19697 
19698   // Try transforming a pair floating point load / store ops to integer
19699   // load / store ops.
19700   if (SDValue NewST = TransformFPLoadStorePair(N))
19701     return NewST;
19702 
19703   // Try transforming several stores into STORE (BSWAP).
19704   if (SDValue Store = mergeTruncStores(ST))
19705     return Store;
19706 
19707   if (ST->isUnindexed()) {
19708     // Walk up chain skipping non-aliasing memory nodes, on this store and any
19709     // adjacent stores.
19710     if (findBetterNeighborChains(ST)) {
19711       // replaceStoreChain uses CombineTo, which handles all of the worklist
19712       // manipulation. Return the original node so that nothing else is done.
19713       return SDValue(ST, 0);
19714     }
19715     Chain = ST->getChain();
19716   }
19717 
19718   // FIXME: is there such a thing as a truncating indexed store?
19719   if (ST->isTruncatingStore() && ST->isUnindexed() &&
19720       Value.getValueType().isInteger() &&
19721       (!isa<ConstantSDNode>(Value) ||
19722        !cast<ConstantSDNode>(Value)->isOpaque())) {
19723     // Convert a truncating store of an extension into a standard store.
19724     if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
19725          Value.getOpcode() == ISD::SIGN_EXTEND ||
19726          Value.getOpcode() == ISD::ANY_EXTEND) &&
19727         Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
19728         TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
19729       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19730                           ST->getMemOperand());
19731 
19732     APInt TruncDemandedBits =
19733         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
19734                              ST->getMemoryVT().getScalarSizeInBits());
19735 
19736     // See if we can simplify the operation with SimplifyDemandedBits, which
19737     // only works if the value has a single use.
19738     AddToWorklist(Value.getNode());
19739     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
19740       // Re-visit the store if anything changed and the store hasn't been merged
19741       // with another node (N is deleted). SimplifyDemandedBits will add Value's
19742       // node back to the worklist if necessary, but we also need to re-visit
19743       // the Store node itself.
19744       if (N->getOpcode() != ISD::DELETED_NODE)
19745         AddToWorklist(N);
19746       return SDValue(N, 0);
19747     }
19748 
19749     // Otherwise, see if we can simplify the input to this truncstore with
19750     // knowledge that only the low bits are being used.  For example:
19751     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
19752     if (SDValue Shorter =
19753             TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
19754       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
19755                                ST->getMemOperand());
19756 
19757     // If we're storing a truncated constant, see if we can simplify it.
19758     // TODO: Move this to targetShrinkDemandedConstant?
19759     if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
19760       if (!Cst->isOpaque()) {
19761         const APInt &CValue = Cst->getAPIntValue();
19762         APInt NewVal = CValue & TruncDemandedBits;
19763         if (NewVal != CValue) {
19764           SDValue Shorter =
19765               DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
19766           return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
19767                                    ST->getMemoryVT(), ST->getMemOperand());
19768         }
19769       }
19770   }
19771 
19772   // If this is a load followed by a store to the same location, then the store
19773   // is dead/noop.
19774   // TODO: Can relax for unordered atomics (see D66309)
19775   if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
19776     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
19777         ST->isUnindexed() && ST->isSimple() &&
19778         Ld->getAddressSpace() == ST->getAddressSpace() &&
19779         // There can't be any side effects between the load and store, such as
19780         // a call or store.
19781         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
19782       // The store is dead, remove it.
19783       return Chain;
19784     }
19785   }
19786 
19787   // TODO: Can relax for unordered atomics (see D66309)
19788   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
19789     if (ST->isUnindexed() && ST->isSimple() &&
19790         ST1->isUnindexed() && ST1->isSimple()) {
19791       if (OptLevel != CodeGenOpt::None && ST1->getBasePtr() == Ptr &&
19792           ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
19793           ST->getAddressSpace() == ST1->getAddressSpace()) {
19794         // If this is a store followed by a store with the same value to the
19795         // same location, then the store is dead/noop.
19796         return Chain;
19797       }
19798 
19799       if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
19800           !ST1->getBasePtr().isUndef() &&
19801           // BaseIndexOffset and the code below requires knowing the size
19802           // of a vector, so bail out if MemoryVT is scalable.
19803           !ST->getMemoryVT().isScalableVector() &&
19804           !ST1->getMemoryVT().isScalableVector() &&
19805           ST->getAddressSpace() == ST1->getAddressSpace()) {
19806         const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
19807         const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
19808         unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
19809         unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
19810         // If the preceding store writes to a subset of the current store's
19811         // location and no other node is chained to that store, we can
19812         // effectively drop it. Do not remove stores to undef as they may
19813         // be used as data sinks.
19814         if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
19815           CombineTo(ST1, ST1->getChain());
19816           return SDValue();
19817         }
19818       }
19819     }
19820   }
19821 
19822   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
19823   // truncating store.  We can do this even if this is already a truncstore.
19824   if ((Value.getOpcode() == ISD::FP_ROUND ||
19825        Value.getOpcode() == ISD::TRUNCATE) &&
19826       Value->hasOneUse() && ST->isUnindexed() &&
19827       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
19828                                ST->getMemoryVT(), LegalOperations)) {
19829     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
19830                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
19831   }
19832 
19833   // Always perform this optimization before types are legal. If the target
19834   // prefers, also try this after legalization to catch stores that were created
19835   // by intrinsics or other nodes.
19836   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
19837     while (true) {
19838       // There can be multiple store sequences on the same chain.
19839       // Keep trying to merge store sequences until we are unable to do so
19840       // or until we merge the last store on the chain.
19841       bool Changed = mergeConsecutiveStores(ST);
19842       if (!Changed) break;
19843       // Return N, as the merge only uses CombineTo and no worklist
19844       // cleanup is necessary.
19845       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
19846         return SDValue(N, 0);
19847     }
19848   }
19849 
19850   // Try transforming N to an indexed store.
19851   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19852     return SDValue(N, 0);
19853 
19854   // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
19855   //
19856   // Make sure to do this only after attempting to merge stores in order to
19857   //  avoid changing the types of some subset of stores due to visit order,
19858   //  preventing their merging.
19859   if (isa<ConstantFPSDNode>(ST->getValue())) {
19860     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
19861       return NewSt;
19862   }
19863 
19864   if (SDValue NewSt = splitMergedValStore(ST))
19865     return NewSt;
19866 
19867   return ReduceLoadOpStoreWidth(N);
19868 }
19869 
19870 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
19871   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
19872   if (!LifetimeEnd->hasOffset())
19873     return SDValue();
19874 
19875   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
19876                                         LifetimeEnd->getOffset(), false);
19877 
19878   // We walk up the chains to find stores.
19879   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
19880   while (!Chains.empty()) {
19881     SDValue Chain = Chains.pop_back_val();
19882     if (!Chain.hasOneUse())
19883       continue;
19884     switch (Chain.getOpcode()) {
19885     case ISD::TokenFactor:
19886       for (unsigned Nops = Chain.getNumOperands(); Nops;)
19887         Chains.push_back(Chain.getOperand(--Nops));
19888       break;
19889     case ISD::LIFETIME_START:
19890     case ISD::LIFETIME_END:
19891       // We can forward past any lifetime start/end that can be proven not to
19892       // alias the node.
19893       if (!mayAlias(Chain.getNode(), N))
19894         Chains.push_back(Chain.getOperand(0));
19895       break;
19896     case ISD::STORE: {
19897       StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
19898       // TODO: Can relax for unordered atomics (see D66309)
19899       if (!ST->isSimple() || ST->isIndexed())
19900         continue;
19901       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
19902       // The bounds of a scalable store are not known until runtime, so this
19903       // store cannot be elided.
19904       if (StoreSize.isScalable())
19905         continue;
19906       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
19907       // If we store purely within object bounds just before its lifetime ends,
19908       // we can remove the store.
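      // For example (a sketch): a simple 8-byte store into an object,
      // immediately followed on the chain by a LIFETIME_END covering those
      // bytes, can be deleted: the stored value cannot be observed once the
      // object's lifetime has ended.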
19909       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
19910                                    StoreSize.getFixedValue() * 8)) {
19911         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
19912                    dbgs() << "\nwithin LIFETIME_END of : ";
19913                    LifetimeEndBase.dump(); dbgs() << "\n");
19914         CombineTo(ST, ST->getChain());
19915         return SDValue(N, 0);
19916       }
19917     }
19918     }
19919   }
19920   return SDValue();
19921 }
19922 
19923 /// For the store instruction sequence below, the F and I values
19924 /// are bundled together as an i64 value before being stored into memory.
19925 /// Sometimes it is more efficient to generate separate stores for F and I,
19926 /// which can remove the bitwise instructions or sink them to colder places.
19927 ///
19928 ///   (store (or (zext (bitcast F to i32) to i64),
19929 ///              (shl (zext I to i64), 32)), addr)  -->
19930 ///   (store F, addr) and (store I, addr+4)
19931 ///
19932 /// Similarly, splitting other merged stores can also be beneficial, e.g.:
19933 /// For pair of {i32, i32}, i64 store --> two i32 stores.
19934 /// For pair of {i32, i16}, i64 store --> two i32 stores.
19935 /// For pair of {i16, i16}, i32 store --> two i16 stores.
19936 /// For pair of {i16, i8},  i32 store --> two i16 stores.
19937 /// For pair of {i8, i8},   i16 store --> two i8 stores.
19938 ///
19939 /// We allow each target to determine specifically which kind of splitting is
19940 /// supported.
19941 ///
19942 /// The store patterns are commonly seen from the simple code snippet below
19943 /// if only std::make_pair(...) is SROA-transformed before being inlined into hoo.
19944 ///   void goo(const std::pair<int, float> &);
19945 ///   hoo() {
19946 ///     ...
19947 ///     goo(std::make_pair(tmp, ftmp));
19948 ///     ...
19949 ///   }
19950 ///
19951 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
19952   if (OptLevel == CodeGenOpt::None)
19953     return SDValue();
19954 
19955   // Can't change the number of memory accesses for a volatile store or break
19956   // atomicity for an atomic one.
19957   if (!ST->isSimple())
19958     return SDValue();
19959 
19960   SDValue Val = ST->getValue();
19961   SDLoc DL(ST);
19962 
19963   // Match OR operand.
19964   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
19965     return SDValue();
19966 
19967   // Match SHL operand and get Lower and Higher parts of Val.
19968   SDValue Op1 = Val.getOperand(0);
19969   SDValue Op2 = Val.getOperand(1);
19970   SDValue Lo, Hi;
19971   if (Op1.getOpcode() != ISD::SHL) {
19972     std::swap(Op1, Op2);
19973     if (Op1.getOpcode() != ISD::SHL)
19974       return SDValue();
19975   }
19976   Lo = Op2;
19977   Hi = Op1.getOperand(0);
19978   if (!Op1.hasOneUse())
19979     return SDValue();
19980 
19981   // Match shift amount to HalfValBitSize.
19982   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
19983   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
19984   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
19985     return SDValue();
19986 
19987   // Lo and Hi must be zero-extended from integers no wider than
19988   // HalfValBitSize.
19989   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
19990       !Lo.getOperand(0).getValueType().isScalarInteger() ||
19991       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
19992       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
19993       !Hi.getOperand(0).getValueType().isScalarInteger() ||
19994       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
19995     return SDValue();
19996 
19997   // Use the EVT of low and high parts before bitcast as the input
19998   // of target query.
19999   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
20000                   ? Lo.getOperand(0).getValueType()
20001                   : Lo.getValueType();
20002   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
20003                    ? Hi.getOperand(0).getValueType()
20004                    : Hi.getValueType();
20005   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
20006     return SDValue();
20007 
20008   // Start to split store.
20009   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
20010   AAMDNodes AAInfo = ST->getAAInfo();
20011 
20012   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
20013   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
20014   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
20015   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
20016 
20017   SDValue Chain = ST->getChain();
20018   SDValue Ptr = ST->getBasePtr();
20019   // Lower value store.
20020   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
20021                              ST->getOriginalAlign(), MMOFlags, AAInfo);
20022   Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
20023   // Higher value store.
20024   SDValue St1 = DAG.getStore(
20025       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
20026       ST->getOriginalAlign(), MMOFlags, AAInfo);
20027   return St1;
20028 }
20029 
20030 // Merge an insertion into an existing shuffle:
20031 // (insert_vector_elt (vector_shuffle X, Y, Mask),
20032 //                    (extract_vector_elt X, N), InsIndex)
20033 //   --> (vector_shuffle X, Y, NewMask)
20034 //  and variations where shuffle operands may be CONCAT_VECTORS.
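// For example (an illustrative sketch): with X = <x0,x1,x2,x3>,
// Y = <y0,y1,y2,y3>, and Mask = <0,4,2,4>, folding
//   (insert_vector_elt (vector_shuffle X, Y, <0,4,2,4>),
//                      (extract_vector_elt X, 1), 3)
// only rewrites the last mask element, giving NewMask = <0,4,2,1>.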
20035 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
20036                                 SmallVectorImpl<int> &NewMask, SDValue Elt,
20037                                 unsigned InsIndex) {
20038   if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
20039       !isa<ConstantSDNode>(Elt.getOperand(1)))
20040     return false;
20041 
20042   // Vec's operand 0 is using indices from 0 to N-1 and
20043   // operand 1 from N to 2N - 1, where N is the number of
20044   // elements in the vectors.
20045   SDValue InsertVal0 = Elt.getOperand(0);
20046   int ElementOffset = -1;
20047 
20048   // We explore the inputs of the shuffle in order to see if we find the
20049   // source of the extract_vector_elt. If so, we can use it to modify the
20050   // shuffle rather than perform an insert_vector_elt.
20051   SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
20052   ArgWorkList.emplace_back(Mask.size(), Y);
20053   ArgWorkList.emplace_back(0, X);
20054 
20055   while (!ArgWorkList.empty()) {
20056     int ArgOffset;
20057     SDValue ArgVal;
20058     std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
20059 
20060     if (ArgVal == InsertVal0) {
20061       ElementOffset = ArgOffset;
20062       break;
20063     }
20064 
20065     // Peek through concat_vector.
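    // E.g. if ArgVal is (concat_vectors v2i32:A, v2i32:B) at ArgOffset 0,
    // A is pushed with offset 0 and B with offset 2, so extracts from B map
    // to mask indices 2 and 3.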
20066     if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
20067       int CurrentArgOffset =
20068           ArgOffset + ArgVal.getValueType().getVectorNumElements();
20069       int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
20070       for (SDValue Op : reverse(ArgVal->ops())) {
20071         CurrentArgOffset -= Step;
20072         ArgWorkList.emplace_back(CurrentArgOffset, Op);
20073       }
20074 
20075       // Make sure we went through all the elements and did not screw up index
20076       // computation.
20077       assert(CurrentArgOffset == ArgOffset);
20078     }
20079   }
20080 
20081   // If we failed to find a match, see if we can replace an UNDEF shuffle
20082   // operand.
20083   if (ElementOffset == -1) {
20084     if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
20085       return false;
20086     ElementOffset = Mask.size();
20087     Y = InsertVal0;
20088   }
20089 
20090   NewMask.assign(Mask.begin(), Mask.end());
20091   NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
20092   assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
20093          "NewMask[InsIndex] is out of bound");
20094   return true;
20095 }
20096 
20097 // Merge an insertion into an existing shuffle:
20098 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
20099 // InsIndex)
20100 //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
20101 //   CONCAT_VECTORS.
20102 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
20103   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
20104          "Expected extract_vector_elt");
20105   SDValue InsertVal = N->getOperand(1);
20106   SDValue Vec = N->getOperand(0);
20107 
20108   auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
20109   if (!SVN || !Vec.hasOneUse())
20110     return SDValue();
20111 
20112   ArrayRef<int> Mask = SVN->getMask();
20113   SDValue X = Vec.getOperand(0);
20114   SDValue Y = Vec.getOperand(1);
20115 
20116   SmallVector<int, 16> NewMask(Mask);
20117   if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
20118     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
20119         Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
20120     if (LegalShuffle)
20121       return LegalShuffle;
20122   }
20123 
20124   return SDValue();
20125 }
20126 
20127 // Convert a disguised subvector insertion into a shuffle:
20128 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
20129 // bitcast(shuffle (bitcast V), (extended X), Mask)
20130 // Note: We do not use an insert_subvector node because that requires a
20131 // legal subvector type.
20132 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
20133   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
20134          "Expected extract_vector_elt");
20135   SDValue InsertVal = N->getOperand(1);
20136 
20137   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
20138       !InsertVal.getOperand(0).getValueType().isVector())
20139     return SDValue();
20140 
20141   SDValue SubVec = InsertVal.getOperand(0);
20142   SDValue DestVec = N->getOperand(0);
20143   EVT SubVecVT = SubVec.getValueType();
20144   EVT VT = DestVec.getValueType();
20145   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
20146   // If the source only has a single vector element, the cost of creating and
20147   // inserting it into a vector likely exceeds the cost of an insert_vector_elt.
20148   if (NumSrcElts == 1)
20149     return SDValue();
20150   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
20151   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
20152 
20153   // Step 1: Create a shuffle mask that implements this insert operation. The
20154   // vector that we are inserting into will be operand 0 of the shuffle, so
20155   // those elements are just 'i'. The inserted subvector is in the first
20156   // positions of operand 1 of the shuffle. Example:
20157   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
20158   SmallVector<int, 16> Mask(NumMaskVals);
20159   for (unsigned i = 0; i != NumMaskVals; ++i) {
20160     if (i / NumSrcElts == InsIndex)
20161       Mask[i] = (i % NumSrcElts) + NumMaskVals;
20162     else
20163       Mask[i] = i;
20164   }
20165 
20166   // Bail out if the target cannot handle the shuffle we want to create.
20167   EVT SubVecEltVT = SubVecVT.getVectorElementType();
20168   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
20169   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
20170     return SDValue();
20171 
20172   // Step 2: Create a wide vector from the inserted source vector by appending
20173   // undefined elements. This is the same size as our destination vector.
20174   SDLoc DL(N);
20175   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
20176   ConcatOps[0] = SubVec;
20177   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
20178 
20179   // Step 3: Shuffle in the padded subvector.
20180   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
20181   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
20182   AddToWorklist(PaddedSubV.getNode());
20183   AddToWorklist(DestVecBC.getNode());
20184   AddToWorklist(Shuf.getNode());
20185   return DAG.getBitcast(VT, Shuf);
20186 }
20187 
20188 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
20189   SDValue InVec = N->getOperand(0);
20190   SDValue InVal = N->getOperand(1);
20191   SDValue EltNo = N->getOperand(2);
20192   SDLoc DL(N);
20193 
20194   EVT VT = InVec.getValueType();
20195   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
20196 
20197   // Inserting into an out-of-bounds element is undefined.
20198   if (IndexC && VT.isFixedLengthVector() &&
20199       IndexC->getZExtValue() >= VT.getVectorNumElements())
20200     return DAG.getUNDEF(VT);
20201 
20202   // Remove redundant insertions:
20203   // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
20204   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20205       InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
20206     return InVec;
20207 
20208   if (!IndexC) {
20209     // If this is variable insert to undef vector, it might be better to splat:
20210     // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
20211     if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
20212       return DAG.getSplat(VT, DL, InVal);
20213     return SDValue();
20214   }
20215 
20216   if (VT.isScalableVector())
20217     return SDValue();
20218 
20219   unsigned NumElts = VT.getVectorNumElements();
20220 
20221   // We must know which element is being inserted for folds below here.
20222   unsigned Elt = IndexC->getZExtValue();
20223 
20224   // Handle <1 x ???> vector insertion special cases.
20225   if (NumElts == 1) {
20226     // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
20227     if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20228         InVal.getOperand(0).getValueType() == VT &&
20229         isNullConstant(InVal.getOperand(1)))
20230       return InVal.getOperand(0);
20231   }
20232 
20233   // Canonicalize insert_vector_elt dag nodes.
20234   // Example:
20235   // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
20236   // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
20237   //
20238   // Do this only if the child insert_vector node has one use; also
20239   // do this only if indices are both constants and Idx1 < Idx0.
20240   if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
20241       && isa<ConstantSDNode>(InVec.getOperand(2))) {
20242     unsigned OtherElt = InVec.getConstantOperandVal(2);
20243     if (Elt < OtherElt) {
20244       // Swap nodes.
20245       SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
20246                                   InVec.getOperand(0), InVal, EltNo);
20247       AddToWorklist(NewOp.getNode());
20248       return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
20249                          VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
20250     }
20251   }
20252 
20253   if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
20254     return Shuf;
20255 
20256   if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
20257     return Shuf;
20258 
20259   // Attempt to convert an insert_vector_elt chain into a legal build_vector.
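  // E.g. (hypothetical, assuming the inner insert has one use):
  //   insert_vector_elt (insert_vector_elt undef, x, 0), y, 1
  //   --> build_vector <x, y> for a v2i32 result.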
20260   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
20261     // vXi1 vector - we don't need to recurse.
20262     if (NumElts == 1)
20263       return DAG.getBuildVector(VT, DL, {InVal});
20264 
20265     // If we haven't already collected the element, insert into the op list.
20266     EVT MaxEltVT = InVal.getValueType();
20267     auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
20268                                 unsigned Idx) {
20269       if (!Ops[Idx]) {
20270         Ops[Idx] = Elt;
20271         if (VT.isInteger()) {
20272           EVT EltVT = Elt.getValueType();
20273           MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
20274         }
20275       }
20276     };
20277 
20278     // Ensure all the operands are the same value type, fill any missing
20279     // operands with UNDEF and create the BUILD_VECTOR.
20280     auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
20281       assert(Ops.size() == NumElts && "Unexpected vector size");
20282       for (SDValue &Op : Ops) {
20283         if (Op)
20284           Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
20285         else
20286           Op = DAG.getUNDEF(MaxEltVT);
20287       }
20288       return DAG.getBuildVector(VT, DL, Ops);
20289     };
20290 
20291     SmallVector<SDValue, 8> Ops(NumElts, SDValue());
20292     Ops[Elt] = InVal;
20293 
20294   // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
20295     for (SDValue CurVec = InVec; CurVec;) {
20296       // UNDEF - build new BUILD_VECTOR from already inserted operands.
20297       if (CurVec.isUndef())
20298         return CanonicalizeBuildVector(Ops);
20299 
20300       // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
20301       if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
20302         for (unsigned I = 0; I != NumElts; ++I)
20303           AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
20304         return CanonicalizeBuildVector(Ops);
20305       }
20306 
20307       // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
20308       if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
20309         AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
20310         return CanonicalizeBuildVector(Ops);
20311       }
20312 
20313       // INSERT_VECTOR_ELT - insert operand and continue up the chain.
20314       if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
20315         if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
20316           if (CurIdx->getAPIntValue().ult(NumElts)) {
20317             unsigned Idx = CurIdx->getZExtValue();
20318             AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
20319 
20320             // Found entire BUILD_VECTOR.
20321             if (all_of(Ops, [](SDValue Op) { return !!Op; }))
20322               return CanonicalizeBuildVector(Ops);
20323 
20324             CurVec = CurVec->getOperand(0);
20325             continue;
20326           }
20327 
20328       // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
20329       // update the shuffle mask (and the second operand if we started with a
20330       // unary shuffle) and create a new legal shuffle.
20331       if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
20332         auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
20333         SDValue LHS = SVN->getOperand(0);
20334         SDValue RHS = SVN->getOperand(1);
20335         SmallVector<int, 16> Mask(SVN->getMask());
20336         bool Merged = true;
20337         for (auto I : enumerate(Ops)) {
20338           SDValue &Op = I.value();
20339           if (Op) {
20340             SmallVector<int, 16> NewMask;
20341             if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
20342               Merged = false;
20343               break;
20344             }
20345             Mask = std::move(NewMask);
20346           }
20347         }
20348         if (Merged)
20349           if (SDValue NewShuffle =
20350                   TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
20351             return NewShuffle;
20352       }
20353 
20354       // Failed to find a match in the chain - bail.
20355       break;
20356     }
20357 
20358     // See if we can fill in the missing constant elements as zeros.
20359     // TODO: Should we do this for any constant?
20360     APInt DemandedZeroElts = APInt::getZero(NumElts);
20361     for (unsigned I = 0; I != NumElts; ++I)
20362       if (!Ops[I])
20363         DemandedZeroElts.setBit(I);
20364 
20365     if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
20366       SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
20367                                     : DAG.getConstantFP(0, DL, MaxEltVT);
20368       for (unsigned I = 0; I != NumElts; ++I)
20369         if (!Ops[I])
20370           Ops[I] = Zero;
20371 
20372       return CanonicalizeBuildVector(Ops);
20373     }
20374   }
20375 
20376   return SDValue();
20377 }
20378 
20379 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
20380                                                   SDValue EltNo,
20381                                                   LoadSDNode *OriginalLoad) {
20382   assert(OriginalLoad->isSimple());
20383 
20384   EVT ResultVT = EVE->getValueType(0);
20385   EVT VecEltVT = InVecVT.getVectorElementType();
20386 
20387   // If the vector element type is not a multiple of a byte then we are unable
20388   // to correctly compute an address to load only the extracted element as a
20389   // scalar.
20390   if (!VecEltVT.isByteSized())
20391     return SDValue();
20392 
20393   ISD::LoadExtType ExtTy =
20394       ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
20395   if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
20396       !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
20397     return SDValue();
20398 
20399   Align Alignment = OriginalLoad->getAlign();
20400   MachinePointerInfo MPI;
20401   SDLoc DL(EVE);
20402   if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
20403     int Elt = ConstEltNo->getZExtValue();
20404     unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
20405     MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
20406     Alignment = commonAlignment(Alignment, PtrOff);
20407   } else {
20408     // Discard the pointer info except the address space because the memory
20409     // operand can't represent this new access since the offset is variable.
20410     MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
20411     Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
20412   }
20413 
20414   unsigned IsFast = 0;
20415   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
20416                               OriginalLoad->getAddressSpace(), Alignment,
20417                               OriginalLoad->getMemOperand()->getFlags(),
20418                               &IsFast) ||
20419       !IsFast)
20420     return SDValue();
20421 
20422   SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
20423                                                InVecVT, EltNo);
20424 
20425   // We are replacing a vector load with a scalar load. The new load must have
20426   // identical memory op ordering to the original.
20427   SDValue Load;
20428   if (ResultVT.bitsGT(VecEltVT)) {
20429     // If the result type of vextract is wider than the load, then issue an
20430     // extending load instead.
20431     ISD::LoadExtType ExtType =
20432         TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
20433                                                               : ISD::EXTLOAD;
20434     Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
20435                           NewPtr, MPI, VecEltVT, Alignment,
20436                           OriginalLoad->getMemOperand()->getFlags(),
20437                           OriginalLoad->getAAInfo());
20438     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
20439   } else {
20440     // The result type is narrower than or the same width as the vector element.
20441     Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
20442                        Alignment, OriginalLoad->getMemOperand()->getFlags(),
20443                        OriginalLoad->getAAInfo());
20444     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
20445     if (ResultVT.bitsLT(VecEltVT))
20446       Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
20447     else
20448       Load = DAG.getBitcast(ResultVT, Load);
20449   }
20450   ++OpsNarrowed;
20451   return Load;
20452 }
20453 
20454 /// Transform a vector binary operation into a scalar binary operation by moving
20455 /// the math/logic after an extract element of a vector.
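/// E.g. (hypothetical): extractelt (add v4i32 X, <1,2,3,4>), 2
///   --> add (extractelt X, 2), 3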
20456 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
20457                                        bool LegalOperations) {
20458   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20459   SDValue Vec = ExtElt->getOperand(0);
20460   SDValue Index = ExtElt->getOperand(1);
20461   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
20462   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
20463       Vec->getNumValues() != 1)
20464     return SDValue();
20465 
20466   // Targets may want to avoid this to prevent an expensive register transfer.
20467   if (!TLI.shouldScalarizeBinop(Vec))
20468     return SDValue();
20469 
20470   // Extracting an element of a vector constant is constant-folded, so this
20471   // transform is just replacing a vector op with a scalar op while moving the
20472   // extract.
20473   SDValue Op0 = Vec.getOperand(0);
20474   SDValue Op1 = Vec.getOperand(1);
20475   APInt SplatVal;
20476   if (isAnyConstantBuildVector(Op0, true) ||
20477       ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
20478       isAnyConstantBuildVector(Op1, true) ||
20479       ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
20480     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
20481     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
20482     SDLoc DL(ExtElt);
20483     EVT VT = ExtElt->getValueType(0);
20484     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
20485     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
20486     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
20487   }
20488 
20489   return SDValue();
20490 }
20491 
20492 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
20493 // recursively analyse all of its users and try to model them as
20494 // bit sequence extractions. If all of them agree on the new, narrower element
20495 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
20496 // new element type, do so now.
20497 // This is mainly useful for recovering from legalization that scalarized
20498 // the vector as wide elements; here we try to rebuild it with narrower elements.
20499 //
20500 // Some more nodes could be modelled if that helps cover interesting patterns.
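// E.g. (hypothetical): if (extract_vector_elt v2i64:V, 0) is only used by
// (trunc ... to i32) and (trunc (srl ..., 32) to i32), it can be rebuilt as
// extracts of elements 0 and 1 of (bitcast V to v4i32).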
20501 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
20502     SDNode *N) {
20503   // We perform this optimization post type-legalization because
20504   // the type-legalizer often scalarizes integer-promoted vectors.
20505   // Performing this optimization earlier may cause legalization cycles.
20506   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
20507     return false;
20508 
20509   // TODO: Add support for big-endian.
20510   if (DAG.getDataLayout().isBigEndian())
20511     return false;
20512 
20513   SDValue VecOp = N->getOperand(0);
20514   EVT VecVT = VecOp.getValueType();
20515   assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
20516 
20517   // We must start with a constant extraction index.
20518   auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
20519   if (!IndexC)
20520     return false;
20521 
20522   assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
20523          "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
20524 
20525   // TODO: deal with the case of implicit anyext of the extraction.
20526   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
20527   EVT ScalarVT = N->getValueType(0);
20528   if (VecVT.getScalarType() != ScalarVT)
20529     return false;
20530 
20531   // TODO: deal with the cases other than everything being integer-typed.
20532   if (!ScalarVT.isScalarInteger())
20533     return false;
20534 
20535   struct Entry {
20536     SDNode *Producer;
20537 
20538     // Which bits of VecOp does it contain?
20539     unsigned BitPos;
20540     int NumBits;
20541     // NOTE: the actual width of \p Producer may be wider than NumBits!
20542 
20543     Entry(Entry &&) = default;
20544     Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
20545         : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
20546 
20547     Entry() = delete;
20548     Entry(const Entry &) = delete;
20549     Entry &operator=(const Entry &) = delete;
20550     Entry &operator=(Entry &&) = delete;
20551   };
20552   SmallVector<Entry, 32> Worklist;
20553   SmallVector<Entry, 32> Leafs;
20554 
20555   // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
20556   Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
20557                         /*NumBits=*/VecEltBitWidth);
20558 
20559   while (!Worklist.empty()) {
20560     Entry E = Worklist.pop_back_val();
20561     // Does the node not even use any of the VecOp bits?
20562     if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
20563           E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
20564       return false; // Let the other combines clean this up first.
20565     // Did we fail to model any of the users of the Producer?
20566     bool ProducerIsLeaf = false;
20567     // Look at each user of this Producer.
20568     for (SDNode *User : E.Producer->uses()) {
20569       switch (User->getOpcode()) {
20570       // TODO: support ISD::BITCAST
20571       // TODO: support ISD::ANY_EXTEND
20572       // TODO: support ISD::ZERO_EXTEND
20573       // TODO: support ISD::SIGN_EXTEND
20574       case ISD::TRUNCATE:
20575         // Truncation simply means we keep the position but extract fewer bits.
20576         Worklist.emplace_back(User, E.BitPos,
20577                               /*NumBits=*/User->getValueSizeInBits(0));
20578         break;
20579       // TODO: support ISD::SRA
20580       // TODO: support ISD::SHL
20581       case ISD::SRL:
20582         // We should be shifting the Producer by a constant amount.
20583         if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
20584             User->getOperand(0).getNode() == E.Producer && ShAmtC) {
20585           // Logical right-shift means that we start extraction later,
20586           // but stop it at the same position we did previously.
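          // E.g. an entry covering bits [0, 32) shifted right by 8 becomes
          // an entry covering bits [8, 32): BitPos += 8, NumBits -= 8.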
20587           unsigned ShAmt = ShAmtC->getZExtValue();
20588           Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
20589           break;
20590         }
20591         [[fallthrough]];
20592       default:
20593         // We cannot model this user of the Producer,
20594         // which means the current Producer will be an ISD::EXTRACT_VECTOR_ELT.
20595         ProducerIsLeaf = true;
20596         // Profitability check: all users that we cannot model
20597         //                      must be ISD::BUILD_VECTOR's.
20598         if (User->getOpcode() != ISD::BUILD_VECTOR)
20599           return false;
20600         break;
20601       }
20602     }
20603     if (ProducerIsLeaf)
20604       Leafs.emplace_back(std::move(E));
20605   }
20606 
20607   unsigned NewVecEltBitWidth = Leafs.front().NumBits;
20608 
20609   // If we are still at the same element granularity, give up,
20610   if (NewVecEltBitWidth == VecEltBitWidth)
20611     return false;
20612 
20613   // The vector width must be a multiple of the new element width.
20614   if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
20615     return false;
20616 
20617   // All leafs must agree on the new element width.
20618   // All leafs must not expect any "padding" bits on top of that width.
20619   // All leafs must start extraction at a multiple of that width.
20620   if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
20621         return (unsigned)E.NumBits == NewVecEltBitWidth &&
20622                E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
20623                E.BitPos % NewVecEltBitWidth == 0;
20624       }))
20625     return false;
20626 
20627   EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
20628   EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
20629                                   VecVT.getSizeInBits() / NewVecEltBitWidth);
20630 
20631   if (LegalTypes &&
20632       !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
20633     return false;
20634 
20635   if (LegalOperations &&
20636       !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
20637         TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
20638     return false;
20639 
20640   SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
20641   for (const Entry &E : Leafs) {
20642     SDLoc DL(E.Producer);
20643     unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
20644     assert(NewIndex < NewVecVT.getVectorNumElements() &&
20645            "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
20646     SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
20647                             DAG.getVectorIdxConstant(NewIndex, DL));
20648     CombineTo(E.Producer, V);
20649   }
20650 
20651   return true;
20652 }
20653 
20654 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
20655   SDValue VecOp = N->getOperand(0);
20656   SDValue Index = N->getOperand(1);
20657   EVT ScalarVT = N->getValueType(0);
20658   EVT VecVT = VecOp.getValueType();
20659   if (VecOp.isUndef())
20660     return DAG.getUNDEF(ScalarVT);
20661 
20662   // (extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
20663   //
20664   // This only really matters if the index is non-constant since other combines
20665   // on the constant elements already work.
20666   SDLoc DL(N);
20667   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
20668       Index == VecOp.getOperand(2)) {
20669     SDValue Elt = VecOp.getOperand(1);
20670     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
20671   }
20672 
20673   // (vextract (scalar_to_vector val), 0) -> val
20674   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
20675     // Only the 0'th element of SCALAR_TO_VECTOR is defined.
20676     if (DAG.isKnownNeverZero(Index))
20677       return DAG.getUNDEF(ScalarVT);
20678 
20679     // Check if the result type doesn't match the inserted element type. A
20680     // SCALAR_TO_VECTOR may truncate the inserted element and the
20681     // EXTRACT_VECTOR_ELT may widen the extracted element.
20682     SDValue InOp = VecOp.getOperand(0);
20683     if (InOp.getValueType() != ScalarVT) {
20684       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger() &&
20685              InOp.getValueType().bitsGT(ScalarVT));
20686       return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
20687     }
20688     return InOp;
20689   }
20690 
20691   // extract_vector_elt of out-of-bounds element -> UNDEF
20692   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
20693   if (IndexC && VecVT.isFixedLengthVector() &&
20694       IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
20695     return DAG.getUNDEF(ScalarVT);
20696 
20697   // extract_vector_elt(freeze(x)), idx -> freeze(extract_vector_elt(x)), idx
20698   if (VecOp.hasOneUse() && VecOp.getOpcode() == ISD::FREEZE) {
20699     return DAG.getFreeze(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
20700                                      VecOp.getOperand(0), Index));
20701   }
20702 
20703   // extract_vector_elt (build_vector x, y), 1 -> y
20704   if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
20705        VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
20706       TLI.isTypeLegal(VecVT) &&
20707       (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
20708     assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
20709             VecVT.isFixedLengthVector()) &&
20710            "BUILD_VECTOR used for scalable vectors");
20711     unsigned IndexVal =
20712         VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
20713     SDValue Elt = VecOp.getOperand(IndexVal);
20714     EVT InEltVT = Elt.getValueType();
20715 
20716     // Sometimes a build_vector's scalar input types do not match the result type.
20717     if (ScalarVT == InEltVT)
20718       return Elt;
20719 
20720     // TODO: It may be useful to truncate when it is free and the
20721     // build_vector implicitly converts.
20722   }
20723 
20724   if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
20725     return BO;
20726 
20727   if (VecVT.isScalableVector())
20728     return SDValue();
20729 
20730   // All the code from this point onwards assumes fixed width vectors, but it's
20731   // possible that some of the combinations could be made to work for scalable
20732   // vectors too.
20733   unsigned NumElts = VecVT.getVectorNumElements();
20734   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
20735 
20736   // TODO: These transforms should not require the 'hasOneUse' restriction, but
20737   // there are regressions on multiple targets without it. We can end up with a
20738   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
20739   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
20740       VecOp.hasOneUse()) {
20741     // The vector index of the LSBs of the source depends on the endianness.
20742     bool IsLE = DAG.getDataLayout().isLittleEndian();
20743     unsigned ExtractIndex = IndexC->getZExtValue();
20744     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
20745     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
20746     SDValue BCSrc = VecOp.getOperand(0);
20747     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
20748       return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
20749 
20750     if (LegalTypes && BCSrc.getValueType().isInteger() &&
20751         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
20752       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
20753       // trunc i64 X to i32
20754       SDValue X = BCSrc.getOperand(0);
20755       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
20756              "Extract element and scalar to vector can't change element type "
20757              "from FP to integer.");
20758       unsigned XBitWidth = X.getValueSizeInBits();
20759       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
20760 
20761       // An extract element return value type can be wider than its vector
20762       // operand element type. In that case, the high bits are undefined, so
20763       // it's possible that we may need to extend rather than truncate.
20764       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
20765         assert(XBitWidth % VecEltBitWidth == 0 &&
20766                "Scalar bitwidth must be a multiple of vector element bitwidth");
20767         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
20768       }
20769     }
20770   }
20771 
20772   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
20773   // We only perform this optimization before the op legalization phase because
20774   // we may introduce new vector instructions which are not backed by TD
20775   // patterns; for example, on AVX, extracting elements from a wide vector
20776   // without using extract_subvector. However, if we can find an underlying
20777   // scalar value, then we can always use that.
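  // E.g. (hypothetical): extract_elt (shuffle v4i32 A, B, <7,u,u,u>), 0
  //   --> extract_elt B, 3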
20778   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
20779     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
20780     // Find the new index to extract from.
20781     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
20782 
20783     // Extracting an undef index is undef.
20784     if (OrigElt == -1)
20785       return DAG.getUNDEF(ScalarVT);
20786 
20787     // Select the right vector half to extract from.
20788     SDValue SVInVec;
20789     if (OrigElt < (int)NumElts) {
20790       SVInVec = VecOp.getOperand(0);
20791     } else {
20792       SVInVec = VecOp.getOperand(1);
20793       OrigElt -= NumElts;
20794     }
20795 
20796     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
20797       SDValue InOp = SVInVec.getOperand(OrigElt);
20798       if (InOp.getValueType() != ScalarVT) {
20799         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
20800         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
20801       }
20802 
20803       return InOp;
20804     }
20805 
20806     // FIXME: We should handle recursing on other vector shuffles and
20807     // scalar_to_vector here as well.
20808 
20809     if (!LegalOperations ||
20810         // FIXME: Should really be just isOperationLegalOrCustom.
20811         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
20812         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
20813       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
20814                          DAG.getVectorIdxConstant(OrigElt, DL));
20815     }
20816   }
20817 
20818   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
20819   // simplify it based on the (valid) extraction indices.
20820   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
20821         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20822                Use->getOperand(0) == VecOp &&
20823                isa<ConstantSDNode>(Use->getOperand(1));
20824       })) {
20825     APInt DemandedElts = APInt::getZero(NumElts);
20826     for (SDNode *Use : VecOp->uses()) {
20827       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
20828       if (CstElt->getAPIntValue().ult(NumElts))
20829         DemandedElts.setBit(CstElt->getZExtValue());
20830     }
20831     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
20832       // We simplified the vector operand of this extract element. If this
20833       // extract is not dead, visit it again so it is folded properly.
20834       if (N->getOpcode() != ISD::DELETED_NODE)
20835         AddToWorklist(N);
20836       return SDValue(N, 0);
20837     }
20838     APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
20839     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
20840       // We simplified the vector operand of this extract element. If this
20841       // extract is not dead, visit it again so it is folded properly.
20842       if (N->getOpcode() != ISD::DELETED_NODE)
20843         AddToWorklist(N);
20844       return SDValue(N, 0);
20845     }
20846   }
20847 
20848   if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
20849     return SDValue(N, 0);
20850 
20851   // Everything under here is trying to match an extract of a loaded value.
20852   // If the result of the load has to be truncated, then it's not necessarily
20853   // profitable.
20854   bool BCNumEltsChanged = false;
20855   EVT ExtVT = VecVT.getVectorElementType();
20856   EVT LVT = ExtVT;
20857   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
20858     return SDValue();
20859 
20860   if (VecOp.getOpcode() == ISD::BITCAST) {
20861     // Don't duplicate a load with other uses.
20862     if (!VecOp.hasOneUse())
20863       return SDValue();
20864 
20865     EVT BCVT = VecOp.getOperand(0).getValueType();
20866     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
20867       return SDValue();
20868     if (NumElts != BCVT.getVectorNumElements())
20869       BCNumEltsChanged = true;
20870     VecOp = VecOp.getOperand(0);
20871     ExtVT = BCVT.getVectorElementType();
20872   }
20873 
20874   // extract (vector load $addr), i --> load $addr + i * size
20875   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
20876       ISD::isNormalLoad(VecOp.getNode()) &&
20877       !Index->hasPredecessor(VecOp.getNode())) {
20878     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
20879     if (VecLoad && VecLoad->isSimple())
20880       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
20881   }
20882 
20883   // Perform only after legalization to ensure build_vector / vector_shuffle
20884   // optimizations have already been done.
20885   if (!LegalOperations || !IndexC)
20886     return SDValue();
20887 
20888   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
20889   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
20890   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
20891   int Elt = IndexC->getZExtValue();
20892   LoadSDNode *LN0 = nullptr;
20893   if (ISD::isNormalLoad(VecOp.getNode())) {
20894     LN0 = cast<LoadSDNode>(VecOp);
20895   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
20896              VecOp.getOperand(0).getValueType() == ExtVT &&
20897              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
20898     // Don't duplicate a load with other uses.
20899     if (!VecOp.hasOneUse())
20900       return SDValue();
20901 
20902     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
20903   }
20904   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
20905     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
20906     // =>
20907     // (load $addr+1*size)
20908 
20909     // Don't duplicate a load with other uses.
20910     if (!VecOp.hasOneUse())
20911       return SDValue();
20912 
20913     // If the bit convert changed the number of elements, it is unsafe
20914     // to examine the mask.
20915     if (BCNumEltsChanged)
20916       return SDValue();
20917 
20918     // Select the input vector, guarding against an out-of-range extract index.
20919     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
20920     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
20921 
20922     if (VecOp.getOpcode() == ISD::BITCAST) {
20923       // Don't duplicate a load with other uses.
20924       if (!VecOp.hasOneUse())
20925         return SDValue();
20926 
20927       VecOp = VecOp.getOperand(0);
20928     }
20929     if (ISD::isNormalLoad(VecOp.getNode())) {
20930       LN0 = cast<LoadSDNode>(VecOp);
20931       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
20932       Index = DAG.getConstant(Elt, DL, Index.getValueType());
20933     }
20934   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
20935              VecVT.getVectorElementType() == ScalarVT &&
20936              (!LegalTypes ||
20937               TLI.isTypeLegal(
20938                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
20939     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
20940     //      -> extract_vector_elt a, 0
20941     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
20942     //      -> extract_vector_elt a, 1
20943     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
20944     //      -> extract_vector_elt b, 0
20945     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
20946     //      -> extract_vector_elt b, 1
20947     SDLoc SL(N);
20948     EVT ConcatVT = VecOp.getOperand(0).getValueType();
20949     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
20950     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
20951                                      Index.getValueType());
20952 
20953     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
20954     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
20955                               ConcatVT.getVectorElementType(),
20956                               ConcatOp, NewIdx);
20957     return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
20958   }
20959 
20960   // Make sure we found a non-volatile load and the extractelement is
20961   // the only use.
20962   if (!LN0 || !LN0->hasNUsesOfValue(1, 0) || !LN0->isSimple())
20963     return SDValue();
20964 
20965   // If Idx was -1 above, Elt is going to be -1, so just return undef.
20966   if (Elt == -1)
20967     return DAG.getUNDEF(LVT);
20968 
20969   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
20970 }
20971 
20972 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
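// E.g. (little-endian, hypothetical types):
//   (v2i32 build_vector (zext i16:a), (zext i16:b))
//   --> (bitcast (v4i16 build_vector a, 0, b, 0) to v2i32)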
20973 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
20974   // We perform this optimization post type-legalization because
20975   // the type-legalizer often scalarizes integer-promoted vectors.
20976   // Performing this optimization before may create bit-casts which
20977   // will be type-legalized to complex code sequences.
20978   // We perform this optimization only before the operation legalizer because we
20979   // may introduce illegal operations.
20980   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
20981     return SDValue();
20982 
20983   unsigned NumInScalars = N->getNumOperands();
20984   SDLoc DL(N);
20985   EVT VT = N->getValueType(0);
20986 
20987   // Check to see if this is a BUILD_VECTOR of a bunch of values
20988   // which come from any_extend or zero_extend nodes. If so, we can create
20989   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
20990   // optimizations. We do not handle sign-extend because we can't fill the sign
20991   // using shuffles.
20992   EVT SourceType = MVT::Other;
20993   bool AllAnyExt = true;
20994 
20995   for (unsigned i = 0; i != NumInScalars; ++i) {
20996     SDValue In = N->getOperand(i);
20997     // Ignore undef inputs.
20998     if (In.isUndef()) continue;
20999 
21000     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
21001     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
21002 
21003     // Abort if the element is not an extension.
21004     if (!ZeroExt && !AnyExt) {
21005       SourceType = MVT::Other;
21006       break;
21007     }
21008 
21009     // The input is a ZeroExt or AnyExt. Check the original type.
21010     EVT InTy = In.getOperand(0).getValueType();
21011 
21012     // Check that all of the widened source types are the same.
21013     if (SourceType == MVT::Other)
21014       // First time.
21015       SourceType = InTy;
21016     else if (InTy != SourceType) {
21017       // Multiple incoming types. Abort.
21018       SourceType = MVT::Other;
21019       break;
21020     }
21021 
21022     // Check if all of the extends are ANY_EXTENDs.
21023     AllAnyExt &= AnyExt;
21024   }
21025 
21026   // In order to have valid types, all of the inputs must be extended from the
21027   // same source type and all of the inputs must be any or zero extend.
21028   // Scalar sizes must be a power of two.
21029   EVT OutScalarTy = VT.getScalarType();
21030   bool ValidTypes = SourceType != MVT::Other &&
21031                  isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
21032                  isPowerOf2_32(SourceType.getSizeInBits());
21033 
21034   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
21035   // turn into a single shuffle instruction.
21036   if (!ValidTypes)
21037     return SDValue();
21038 
21039   // If we already have a splat buildvector, then don't fold it if it means
21040   // introducing zeros.
21041   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
21042     return SDValue();
21043 
21044   bool isLE = DAG.getDataLayout().isLittleEndian();
21045   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
21046   assert(ElemRatio > 1 && "Invalid element size ratio");
21047   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
21048                                DAG.getConstant(0, DL, SourceType);
21049 
21050   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
21051   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
21052 
21053   // Populate the new build_vector
21054   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
21055     SDValue Cast = N->getOperand(i);
21056     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
21057             Cast.getOpcode() == ISD::ZERO_EXTEND ||
21058             Cast.isUndef()) && "Invalid cast opcode");
21059     SDValue In;
21060     if (Cast.isUndef())
21061       In = DAG.getUNDEF(SourceType);
21062     else
21063       In = Cast->getOperand(0);
21064     unsigned Index = isLE ? (i * ElemRatio) :
21065                             (i * ElemRatio + (ElemRatio - 1));
21066 
21067     assert(Index < Ops.size() && "Invalid index");
21068     Ops[Index] = In;
21069   }
21070 
21071   // The type of the new BUILD_VECTOR node.
21072   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
21073   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
21074          "Invalid vector size");
21075   // Check if the new vector type is legal.
21076   if (!isTypeLegal(VecVT) ||
21077       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
21078        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
21079     return SDValue();
21080 
21081   // Make the new BUILD_VECTOR.
21082   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
21083 
21084   // The new BUILD_VECTOR node has the potential to be further optimized.
21085   AddToWorklist(BV.getNode());
21086   // Bitcast to the desired type.
21087   return DAG.getBitcast(VT, BV);
21088 }
21089 
21090 // Simplify (build_vec (trunc $1)
21091 //                     (trunc (srl $1 half-width))
21092 //                     (trunc (srl $1 (2 * half-width))))
21093 // to (bitcast $1)
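// E.g. (little-endian, hypothetical): for i64 %x,
//   (v2i32 build_vector (trunc %x), (trunc (srl %x, 32)))
//   --> (bitcast %x to v2i32)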
21094 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
21095   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
21096 
21097   // Only for little endian
21098   if (!DAG.getDataLayout().isLittleEndian())
21099     return SDValue();
21100 
21101   SDLoc DL(N);
21102   EVT VT = N->getValueType(0);
21103   EVT OutScalarTy = VT.getScalarType();
21104   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
21105 
21106   // Only for power-of-two types, to be sure that the bitcast works well.
21107   if (!isPowerOf2_64(ScalarTypeBitsize))
21108     return SDValue();
21109 
21110   unsigned NumInScalars = N->getNumOperands();
21111 
21112   // Look through bitcasts
21113   auto PeekThroughBitcast = [](SDValue Op) {
21114     if (Op.getOpcode() == ISD::BITCAST)
21115       return Op.getOperand(0);
21116     return Op;
21117   };
21118 
21119   // The source value from which all the parts are extracted.
21120   SDValue Src;
21121   for (unsigned i = 0; i != NumInScalars; ++i) {
21122     SDValue In = PeekThroughBitcast(N->getOperand(i));
21123     // Ignore undef inputs.
21124     if (In.isUndef()) continue;
21125 
21126     if (In.getOpcode() != ISD::TRUNCATE)
21127       return SDValue();
21128 
21129     In = PeekThroughBitcast(In.getOperand(0));
21130 
21131     if (In.getOpcode() != ISD::SRL) {
21132       // For now, handle only build_vectors without shuffling; handle shifts
21133       // here in the future.
21134       if (i != 0)
21135         return SDValue();
21136 
21137       Src = In;
21138     } else {
21139       // In is SRL
21140       SDValue part = PeekThroughBitcast(In.getOperand(0));
21141 
21142       if (!Src) {
21143         Src = part;
21144       } else if (Src != part) {
21145         // Vector parts do not stem from the same variable
21146         return SDValue();
21147       }
21148 
21149       SDValue ShiftAmtVal = In.getOperand(1);
21150       if (!isa<ConstantSDNode>(ShiftAmtVal))
21151         return SDValue();
21152 
21153       uint64_t ShiftAmt = In.getConstantOperandVal(1);
21154 
21155       // The value is not extracted from the expected position.
21156       if (ShiftAmt != i * ScalarTypeBitsize)
21157         return SDValue();
21158     }
21159   }
21160 
21161   // Only cast if the size is the same
21162   if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
21163     return SDValue();
21164 
21165   return DAG.getBitcast(VT, Src);
21166 }
21167 
21168 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
21169                                            ArrayRef<int> VectorMask,
21170                                            SDValue VecIn1, SDValue VecIn2,
21171                                            unsigned LeftIdx, bool DidSplitVec) {
21172   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
21173 
21174   EVT VT = N->getValueType(0);
21175   EVT InVT1 = VecIn1.getValueType();
21176   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
21177 
21178   unsigned NumElems = VT.getVectorNumElements();
21179   unsigned ShuffleNumElems = NumElems;
21180 
21181   // If we artificially split a vector in two already, then the offsets in the
21182   // operands will all be based off of VecIn1, even those in VecIn2.
21183   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
21184 
21185   uint64_t VTSize = VT.getFixedSizeInBits();
21186   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
21187   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
21188 
21189   assert(InVT2Size <= InVT1Size &&
21190          "Inputs must be sorted to be in non-increasing vector size order.");
21191 
21192   // We can't generate a shuffle node with mismatched input and output types.
21193   // Try to make the types match the type of the output.
21194   if (InVT1 != VT || InVT2 != VT) {
21195     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
21196       // If the output vector length is a multiple of both input lengths,
21197       // we can concatenate them and pad the rest with undefs.
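      // E.g. (hypothetical) a v8i32 output from two v2i32 inputs becomes
      //   concat_vectors VecIn1, VecIn2, undef, undef.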
21198       unsigned NumConcats = VTSize / InVT1Size;
21199       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
21200       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
21201       ConcatOps[0] = VecIn1;
21202       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
21203       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
21204       VecIn2 = SDValue();
21205     } else if (InVT1Size == VTSize * 2) {
21206       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
21207         return SDValue();
21208 
21209       if (!VecIn2.getNode()) {
21210         // If we only have one input vector, and it's twice the size of the
21211         // output, split it in two.
21212         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
21213                              DAG.getVectorIdxConstant(NumElems, DL));
21214         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
21215         // Since we now have shorter input vectors, adjust the offset of the
21216         // second vector's start.
21217         Vec2Offset = NumElems;
21218       } else {
21219         assert(InVT2Size <= InVT1Size &&
21220                "Second input is not going to be larger than the first one.");
21221 
21222         // VecIn1 is wider than the output, and we have another, possibly
21223         // smaller input. Pad the smaller input with undefs, shuffle at the
21224         // input vector width, and extract the output.
21225         // The shuffle type is different than VT, so check legality again.
21226         if (LegalOperations &&
21227             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
21228           return SDValue();
21229 
21230         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
21231         // lower it back into a BUILD_VECTOR. So if the inserted type is
21232         // illegal, don't even try.
21233         if (InVT1 != InVT2) {
21234           if (!TLI.isTypeLegal(InVT2))
21235             return SDValue();
21236           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
21237                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
21238         }
21239         ShuffleNumElems = NumElems * 2;
21240       }
21241     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
21242       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
21243       ConcatOps[0] = VecIn2;
21244       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
21245     } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
21246       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
21247           !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
21248         return SDValue();
21249       // If the dest vector has fewer than two elements, then using a shuffle
21250       // and extracting from larger regs will cost even more.
21251       if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
21252         return SDValue();
21253       assert(InVT2Size <= InVT1Size &&
21254              "Second input is not going to be larger than the first one.");
21255 
21256       // VecIn1 is wider than the output, and we have another, possibly
21257       // smaller input. Pad the smaller input with undefs, shuffle at the
21258       // input vector width, and extract the output.
21259       // The shuffle type is different than VT, so check legality again.
21260       if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
21261         return SDValue();
21262 
21263       if (InVT1 != InVT2) {
21264         VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
21265                              DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
21266       }
21267       ShuffleNumElems = InVT1Size / VTSize * NumElems;
21268     } else {
21269       // TODO: Support cases where the length mismatch isn't exactly by a
21270       // factor of 2.
21271       // TODO: Move this check upwards, so that if we have bad type
21272       // mismatches, we don't create any DAG nodes.
21273       return SDValue();
21274     }
21275   }
21276 
21277   // Initialize mask to undef.
21278   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
21279 
21280   // Only need to run up to the number of elements actually used, not the
21281   // total number of elements in the shuffle - if we are shuffling a wider
21282   // vector, the high lanes should be set to undef.
21283   for (unsigned i = 0; i != NumElems; ++i) {
21284     if (VectorMask[i] <= 0)
21285       continue;
21286 
21287     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
21288     if (VectorMask[i] == (int)LeftIdx) {
21289       Mask[i] = ExtIndex;
21290     } else if (VectorMask[i] == (int)LeftIdx + 1) {
21291       Mask[i] = Vec2Offset + ExtIndex;
21292     }
21293   }
21294 
21295   // The type the input vectors may have changed above.
21296   InVT1 = VecIn1.getValueType();
21297 
21298   // If we already have a VecIn2, it should have the same type as VecIn1.
21299   // If we don't, get an undef/zero vector of the appropriate type.
21300   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
21301   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
21302 
21303   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
21304   if (ShuffleNumElems > NumElems)
21305     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
21306 
21307   return Shuffle;
21308 }
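// For example, in the routine above (illustrative types, assuming they are
// legal for the target): to build a v4i32 from elements of a v8i32 VecIn1
// and a v4i32 VecIn2, VecIn2 is first widened to v8i32 by inserting it into
// an undef vector at index 0, the two inputs are then shuffled at v8i32
// width (only the low 4 mask elements are meaningful), and the final v4i32
// result is taken with an extract_subvector at index 0.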
21309 
21310 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
21311   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
21312 
21313   // First, determine where the build vector is not undef.
21314   // TODO: We could extend this to handle zero elements as well as undefs.
21315   int NumBVOps = BV->getNumOperands();
21316   int ZextElt = -1;
21317   for (int i = 0; i != NumBVOps; ++i) {
21318     SDValue Op = BV->getOperand(i);
21319     if (Op.isUndef())
21320       continue;
21321     if (ZextElt == -1)
21322       ZextElt = i;
21323     else
21324       return SDValue();
21325   }
21326   // Bail out if there's no non-undef element.
21327   if (ZextElt == -1)
21328     return SDValue();
21329 
21330   // The build vector contains some number of undef elements and exactly
21331   // one other element. That other element must be a zero-extended scalar
21332   // extracted from a vector at a constant index to turn this into a shuffle.
21333   // Also, require that the build vector does not implicitly truncate/extend
21334   // its elements.
21335   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
21336   EVT VT = BV->getValueType(0);
21337   SDValue Zext = BV->getOperand(ZextElt);
21338   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
21339       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21340       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
21341       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
21342     return SDValue();
21343 
21344   // The zero-extend width must be a multiple of the source width, and we must be
21345   // building a vector of the same size as the source of the extract element.
21346   SDValue Extract = Zext.getOperand(0);
21347   unsigned DestSize = Zext.getValueSizeInBits();
21348   unsigned SrcSize = Extract.getValueSizeInBits();
21349   if (DestSize % SrcSize != 0 ||
21350       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
21351     return SDValue();
21352 
21353   // Create a shuffle mask that will combine the extracted element with zeros
21354   // and undefs.
21355   int ZextRatio = DestSize / SrcSize;
21356   int NumMaskElts = NumBVOps * ZextRatio;
21357   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
21358   for (int i = 0; i != NumMaskElts; ++i) {
21359     if (i / ZextRatio == ZextElt) {
21360       // The low bits of the (potentially translated) extracted element map to
21361       // the source vector. The high bits map to zero. We will use a zero vector
21362       // as the 2nd source operand of the shuffle, so use the 1st element of
21363       // that vector (mask value is number-of-elements) for the high bits.
21364       if (i % ZextRatio == 0)
21365         ShufMask[i] = Extract.getConstantOperandVal(1);
21366       else
21367         ShufMask[i] = NumMaskElts;
21368     }
21369 
21370     // Undef elements of the build vector remain undef because we initialize
21371     // the shuffle mask with -1.
21372   }
21373 
21374   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
21375   // bitcast (shuffle V, ZeroVec, VectorMask)
21376   SDLoc DL(BV);
21377   EVT VecVT = Extract.getOperand(0).getValueType();
21378   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
21379   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21380   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
21381                                              ZeroVec, ShufMask, DAG);
21382   if (!Shuf)
21383     return SDValue();
21384   return DAG.getBitcast(VT, Shuf);
21385 }
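// For example, the fold above turns (illustrative types, assuming the
// resulting shuffle is legal):
//   t1: i32 = extract_vector_elt V:v4i32, Constant:i64<1>
//   t2: i64 = zero_extend t1
//   t3: v2i64 = BUILD_VECTOR t2, undef
// into a v2i64 bitcast of
//   t4: v4i32 = vector_shuffle<1,4,u,u> V, zero:v4i32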
21386 
21387 // FIXME: promote to STLExtras.
21388 template <typename R, typename T>
21389 static auto getFirstIndexOf(R &&Range, const T &Val) {
21390   auto I = find(Range, Val);
21391   if (I == Range.end())
21392     return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
21393   return std::distance(Range.begin(), I);
21394 }
21395 
21396 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
21397 // operations. If the types of the vectors we're extracting from allow it,
21398 // turn this into a vector_shuffle node.
21399 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
21400   SDLoc DL(N);
21401   EVT VT = N->getValueType(0);
21402 
21403   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
21404   if (!isTypeLegal(VT))
21405     return SDValue();
21406 
21407   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
21408     return V;
21409 
21410   // We may only combine to a shuffle after legalization if the shuffle is legal.
21411   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
21412     return SDValue();
21413 
21414   bool UsesZeroVector = false;
21415   unsigned NumElems = N->getNumOperands();
21416 
21417   // Record, for each element of the newly built vector, which input vector
21418   // that element comes from. -1 stands for undef, 0 for the zero vector,
21419   // and positive values for the input vectors.
21420   // VectorMask maps each element to its vector number, and VecIn maps vector
21421   // numbers to their initial SDValues.
21422 
21423   SmallVector<int, 8> VectorMask(NumElems, -1);
21424   SmallVector<SDValue, 8> VecIn;
21425   VecIn.push_back(SDValue());
21426 
21427   for (unsigned i = 0; i != NumElems; ++i) {
21428     SDValue Op = N->getOperand(i);
21429 
21430     if (Op.isUndef())
21431       continue;
21432 
21433     // See if we can use a blend with a zero vector.
21434     // TODO: Should we generalize this to a blend with an arbitrary constant
21435     // vector?
21436     if (isNullConstant(Op) || isNullFPConstant(Op)) {
21437       UsesZeroVector = true;
21438       VectorMask[i] = 0;
21439       continue;
21440     }
21441 
21442     // Not an undef or zero. If the input is something other than an
21443     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
21444     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21445         !isa<ConstantSDNode>(Op.getOperand(1)))
21446       return SDValue();
21447     SDValue ExtractedFromVec = Op.getOperand(0);
21448 
21449     if (ExtractedFromVec.getValueType().isScalableVector())
21450       return SDValue();
21451 
21452     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
21453     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
21454       return SDValue();
21455 
21456     // All inputs must have the same element type as the output.
21457     if (VT.getVectorElementType() !=
21458         ExtractedFromVec.getValueType().getVectorElementType())
21459       return SDValue();
21460 
21461     // Have we seen this input vector before?
21462     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
21463     // a map back from SDValues to numbers isn't worth it.
21464     int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
21465     if (Idx == -1) { // A new source vector?
21466       Idx = VecIn.size();
21467       VecIn.push_back(ExtractedFromVec);
21468     }
21469 
21470     VectorMask[i] = Idx;
21471   }
21472 
21473   // If we didn't find at least one input vector, bail out.
21474   if (VecIn.size() < 2)
21475     return SDValue();
21476 
21477   // If all the operands of the BUILD_VECTOR extract from the same
21478   // vector, then split the vector efficiently based on the maximum
21479   // vector access index and adjust the VectorMask and
21480   // VecIn accordingly (a worked example follows the split below).
21481   bool DidSplitVec = false;
21482   if (VecIn.size() == 2) {
21483     unsigned MaxIndex = 0;
21484     unsigned NearestPow2 = 0;
21485     SDValue Vec = VecIn.back();
21486     EVT InVT = Vec.getValueType();
21487     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
21488 
21489     for (unsigned i = 0; i < NumElems; i++) {
21490       if (VectorMask[i] <= 0)
21491         continue;
21492       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
21493       IndexVec[i] = Index;
21494       MaxIndex = std::max(MaxIndex, Index);
21495     }
21496 
21497     NearestPow2 = PowerOf2Ceil(MaxIndex);
21498     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
21499         NumElems * 2 < NearestPow2) {
21500       unsigned SplitSize = NearestPow2 / 2;
21501       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
21502                                      InVT.getVectorElementType(), SplitSize);
21503       if (TLI.isTypeLegal(SplitVT) &&
21504           SplitSize + SplitVT.getVectorNumElements() <=
21505               InVT.getVectorNumElements()) {
21506         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
21507                                      DAG.getVectorIdxConstant(SplitSize, DL));
21508         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
21509                                      DAG.getVectorIdxConstant(0, DL));
21510         VecIn.pop_back();
21511         VecIn.push_back(VecIn1);
21512         VecIn.push_back(VecIn2);
21513         DidSplitVec = true;
21514 
21515         for (unsigned i = 0; i < NumElems; i++) {
21516           if (VectorMask[i] <= 0)
21517             continue;
21518           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
21519         }
21520       }
21521     }
21522   }
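  // For example (illustrative, assuming v8i32 is a legal type): when building
  // a v4i32 whose elements all extract from one v16i32 source with
  // MaxIndex == 9, we get NearestPow2 == 16 and SplitSize == 8; the source is
  // then split into two v8i32 halves (extract_subvector at indices 0 and 8),
  // and the VectorMask entries are remapped to 1 (low half) or 2 (high half).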
21523 
21524   // Sort input vectors by decreasing vector element count,
21525   // while preserving the relative order of equally-sized vectors.
21526   // Note that we keep the first "implicit" zero vector as-is.
21527   SmallVector<SDValue, 8> SortedVecIn(VecIn);
21528   llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
21529                     [](const SDValue &a, const SDValue &b) {
21530                       return a.getValueType().getVectorNumElements() >
21531                              b.getValueType().getVectorNumElements();
21532                     });
21533 
21534   // We now also need to rebuild the VectorMask, because it referenced element
21535   // order in VecIn, and we just sorted them.
21536   for (int &SourceVectorIndex : VectorMask) {
21537     if (SourceVectorIndex <= 0)
21538       continue;
21539     unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
21540     assert(Idx > 0 && Idx < SortedVecIn.size() &&
21541            VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
21542     SourceVectorIndex = Idx;
21543   }
21544 
21545   VecIn = std::move(SortedVecIn);
21546 
21547   // TODO: Should this fire if some of the input vectors have illegal types (like
21548   // it does now), or should we let legalization run its course first?
21549 
21550   // Shuffle phase:
21551   // Take pairs of vectors, and shuffle them so that the result has elements
21552   // from these vectors in the correct places.
21553   // For example, given:
21554   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
21555   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
21556   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
21557   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
21558   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
21559   // We will generate:
21560   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
21561   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
21562   SmallVector<SDValue, 4> Shuffles;
21563   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
21564     unsigned LeftIdx = 2 * In + 1;
21565     SDValue VecLeft = VecIn[LeftIdx];
21566     SDValue VecRight =
21567         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
21568 
21569     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
21570                                                 VecRight, LeftIdx, DidSplitVec))
21571       Shuffles.push_back(Shuffle);
21572     else
21573       return SDValue();
21574   }
21575 
21576   // If we need the zero vector as an "ingredient" in the blend tree, add it
21577   // to the list of shuffles.
21578   if (UsesZeroVector)
21579     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
21580                                       : DAG.getConstantFP(0.0, DL, VT));
21581 
21582   // If we only have one shuffle, we're done.
21583   if (Shuffles.size() == 1)
21584     return Shuffles[0];
21585 
21586   // Update the vector mask to point to the post-shuffle vectors.
21587   for (int &Vec : VectorMask)
21588     if (Vec == 0)
21589       Vec = Shuffles.size() - 1;
21590     else
21591       Vec = (Vec - 1) / 2;
21592 
21593   // More than one shuffle. Generate a binary tree of blends, e.g. if from
21594   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
21595   // generate:
21596   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
21597   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
21598   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
21599   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
21600   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
21601   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
21602   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
21603 
21604   // Make sure the initial size of the shuffle list is even.
21605   if (Shuffles.size() % 2)
21606     Shuffles.push_back(DAG.getUNDEF(VT));
21607 
21608   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
21609     if (CurSize % 2) {
21610       Shuffles[CurSize] = DAG.getUNDEF(VT);
21611       CurSize++;
21612     }
21613     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
21614       int Left = 2 * In;
21615       int Right = 2 * In + 1;
21616       SmallVector<int, 8> Mask(NumElems, -1);
21617       SDValue L = Shuffles[Left];
21618       ArrayRef<int> LMask;
21619       bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
21620                            L.use_empty() && L.getOperand(1).isUndef() &&
21621                            L.getOperand(0).getValueType() == L.getValueType();
21622       if (IsLeftShuffle) {
21623         LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
21624         L = L.getOperand(0);
21625       }
21626       SDValue R = Shuffles[Right];
21627       ArrayRef<int> RMask;
21628       bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
21629                             R.use_empty() && R.getOperand(1).isUndef() &&
21630                             R.getOperand(0).getValueType() == R.getValueType();
21631       if (IsRightShuffle) {
21632         RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
21633         R = R.getOperand(0);
21634       }
21635       for (unsigned I = 0; I != NumElems; ++I) {
21636         if (VectorMask[I] == Left) {
21637           Mask[I] = I;
21638           if (IsLeftShuffle)
21639             Mask[I] = LMask[I];
21640           VectorMask[I] = In;
21641         } else if (VectorMask[I] == Right) {
21642           Mask[I] = I + NumElems;
21643           if (IsRightShuffle)
21644             Mask[I] = RMask[I] + NumElems;
21645           VectorMask[I] = In;
21646         }
21647       }
21648 
21649       Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
21650     }
21651   }
21652   return Shuffles[0];
21653 }
21654 
21655 // Try to turn a build vector of zero extends of extract vector elts into a
21656 // vector zero extend and possibly an extract subvector.
21657 // TODO: Support sign extend?
21658 // TODO: Allow undef elements?
21659 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
21660   if (LegalOperations)
21661     return SDValue();
21662 
21663   EVT VT = N->getValueType(0);
21664 
21665   bool FoundZeroExtend = false;
21666   SDValue Op0 = N->getOperand(0);
21667   auto checkElem = [&](SDValue Op) -> int64_t {
21668     unsigned Opc = Op.getOpcode();
21669     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
21670     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
21671         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21672         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
21673       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
21674         return C->getZExtValue();
21675     return -1;
21676   };
21677 
21678   // Make sure the first element matches
21679   // (zext (extract_vector_elt X, C))
21680   // Offset must be a constant multiple of the
21681   // known-minimum vector length of the result type.
21682   int64_t Offset = checkElem(Op0);
21683   if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
21684     return SDValue();
21685 
21686   unsigned NumElems = N->getNumOperands();
21687   SDValue In = Op0.getOperand(0).getOperand(0);
21688   EVT InSVT = In.getValueType().getScalarType();
21689   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
21690 
21691   // Don't create an illegal input type after type legalization.
21692   if (LegalTypes && !TLI.isTypeLegal(InVT))
21693     return SDValue();
21694 
21695   // Ensure all the elements come from the same vector and are adjacent.
21696   for (unsigned i = 1; i != NumElems; ++i) {
21697     if ((Offset + i) != checkElem(N->getOperand(i)))
21698       return SDValue();
21699   }
21700 
21701   SDLoc DL(N);
21702   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
21703                    Op0.getOperand(0).getOperand(1));
21704   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
21705                      VT, In);
21706 }
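// For example, the fold above turns (illustrative types, assuming v4i16 and
// v4i32 are legal):
//   t1: v4i32 = BUILD_VECTOR (zero_extend (extract_vector_elt X:v16i16, 4)),
//                            ..., (zero_extend (extract_vector_elt X, 7))
// into
//   t2: v4i16 = extract_subvector X, Constant:i64<4>
//   t3: v4i32 = zero_extend t2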
21707 
21708 // If this is a very simple BUILD_VECTOR whose first element is a ZERO_EXTEND,
21709 // and all other elements are constant zeros, granularize the BUILD_VECTOR's
21710 // element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
21711 // This pattern can appear during legalization.
21712 //
21713 // NOTE: This can be generalized to allow more than a single
21714 //       non-constant-zero op, UNDEFs, and to be KnownBits-based.
21715 SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
21716   // Don't run this after legalization. Targets may have other preferences.
21717   if (Level >= AfterLegalizeDAG)
21718     return SDValue();
21719 
21720   // FIXME: support big-endian.
21721   if (DAG.getDataLayout().isBigEndian())
21722     return SDValue();
21723 
21724   EVT VT = N->getValueType(0);
21725   EVT OpVT = N->getOperand(0).getValueType();
21726   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
21727 
21728   EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
21729 
21730   if (!TLI.isTypeLegal(OpIntVT) ||
21731       (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
21732     return SDValue();
21733 
21734   unsigned EltBitwidth = VT.getScalarSizeInBits();
21735   // NOTE: the actual width of operands may be wider than that!
21736 
21737   // Analyze all operands of this BUILD_VECTOR. What is the largest number of
21738   // active bits they all have? We'll want to truncate them all to that width.
21739   unsigned ActiveBits = 0;
21740   APInt KnownZeroOps(VT.getVectorNumElements(), 0);
21741   for (auto I : enumerate(N->ops())) {
21742     SDValue Op = I.value();
21743     // FIXME: support UNDEF elements?
21744     if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
21745       unsigned OpActiveBits =
21746           Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
21747       if (OpActiveBits == 0) {
21748         KnownZeroOps.setBit(I.index());
21749         continue;
21750       }
21751       // Profitability check: don't allow non-zero constant operands.
21752       return SDValue();
21753     }
21754     // Profitability check: there must only be a single non-zero operand,
21755     // and it must be the first operand of the BUILD_VECTOR.
21756     if (I.index() != 0)
21757       return SDValue();
21758     // The operand must be a zero-extension itself.
21759     // FIXME: this could be generalized to known leading zeros check.
21760     if (Op.getOpcode() != ISD::ZERO_EXTEND)
21761       return SDValue();
21762     unsigned CurrActiveBits =
21763         Op.getOperand(0).getValueSizeInBits().getFixedValue();
21764     assert(!ActiveBits && "Already encountered non-constant-zero operand?");
21765     ActiveBits = CurrActiveBits;
21766     // We want to at least halve the element size.
21767     if (2 * ActiveBits > EltBitwidth)
21768       return SDValue();
21769   }
21770 
21771   // This BUILD_VECTOR must have at least one non-constant-zero operand.
21772   if (ActiveBits == 0)
21773     return SDValue();
21774 
21775   // We have EltBitwidth bits and the *minimal* chunk size is ActiveBits;
21776   // into how many chunks can we split our element width?
21777   EVT NewScalarIntVT, NewIntVT;
21778   std::optional<unsigned> Factor;
21779   // We can split the element into at least two chunks, but not into more
21780   // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
21781   // that evenly divides the element width,
21782   // and for which the resulting types/operations on that chunk width are legal.
21783   assert(2 * ActiveBits <= EltBitwidth &&
21784          "We know that half or less bits of the element are active.");
21785   for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
21786     if (EltBitwidth % Scale != 0)
21787       continue;
21788     unsigned ChunkBitwidth = EltBitwidth / Scale;
21789     assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
21790     NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
21791     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
21792                                 Scale * N->getNumOperands());
21793     if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
21794         (LegalOperations &&
21795          !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
21796            TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
21797       continue;
21798     Factor = Scale;
21799     break;
21800   }
21801   if (!Factor)
21802     return SDValue();
21803 
21804   SDLoc DL(N);
21805   SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);
21806 
21807   // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
21808   SmallVector<SDValue, 16> NewOps;
21809   NewOps.reserve(NewIntVT.getVectorNumElements());
21810   for (auto I : enumerate(N->ops())) {
21811     SDValue Op = I.value();
21812     assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
21813     unsigned SrcOpIdx = I.index();
21814     if (KnownZeroOps[SrcOpIdx]) {
21815       NewOps.append(*Factor, ZeroOp);
21816       continue;
21817     }
21818     Op = DAG.getBitcast(OpIntVT, Op);
21819     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
21820     NewOps.emplace_back(Op);
21821     NewOps.append(*Factor - 1, ZeroOp);
21822   }
21823   assert(NewOps.size() == NewIntVT.getVectorNumElements());
21824   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
21825   NewBV = DAG.getBitcast(VT, NewBV);
21826   return NewBV;
21827 }
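// For example, the fold above rewrites (illustrative, little-endian, assuming
// i32 and v4i32 are legal):
//   t1: i64 = zero_extend X:i32
//   t2: v2i64 = BUILD_VECTOR t1, Constant:i64<0>
// with elements half as wide:
//   t3: i32 = truncate t1
//   t4: v4i32 = BUILD_VECTOR t3, 0, 0, 0
//   t5: v2i64 = bitcast t4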
21828 
21829 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
21830   EVT VT = N->getValueType(0);
21831 
21832   // A vector built entirely of undefs is undef.
21833   if (ISD::allOperandsUndef(N))
21834     return DAG.getUNDEF(VT);
21835 
21836   // If this is a splat of a bitcast from another vector, change to a
21837   // concat_vector.
21838   // For example:
21839   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
21840   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
21841   //
21842   // If X is a build_vector itself, the concat can become a larger build_vector.
21843   // TODO: Maybe this is useful for non-splat too?
21844   if (!LegalOperations) {
21845     if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
21846       Splat = peekThroughBitcasts(Splat);
21847       EVT SrcVT = Splat.getValueType();
21848       if (SrcVT.isVector()) {
21849         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
21850         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
21851                                      SrcVT.getVectorElementType(), NumElts);
21852         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
21853           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
21854           SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
21855                                        NewVT, Ops);
21856           return DAG.getBitcast(VT, Concat);
21857         }
21858       }
21859     }
21860   }
21861 
21862   // Check if we can express the BUILD_VECTOR via a subvector extract.
21863   if (!LegalTypes && (N->getNumOperands() > 1)) {
21864     SDValue Op0 = N->getOperand(0);
21865     auto checkElem = [&](SDValue Op) -> uint64_t {
21866       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
21867           (Op0.getOperand(0) == Op.getOperand(0)))
21868         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
21869           return CNode->getZExtValue();
21870       return -1;
21871     };
21872 
21873     int Offset = checkElem(Op0);
21874     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
21875       if (Offset + i != checkElem(N->getOperand(i))) {
21876         Offset = -1;
21877         break;
21878       }
21879     }
21880 
21881     if ((Offset == 0) &&
21882         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
21883       return Op0.getOperand(0);
21884     if ((Offset != -1) &&
21885         ((Offset % N->getValueType(0).getVectorNumElements()) ==
21886          0)) // IDX must be a multiple of the output size.
21887       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
21888                          Op0.getOperand(0), Op0.getOperand(1));
21889   }
21890 
21891   if (SDValue V = convertBuildVecZextToZext(N))
21892     return V;
21893 
21894   if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
21895     return V;
21896 
21897   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
21898     return V;
21899 
21900   if (SDValue V = reduceBuildVecTruncToBitCast(N))
21901     return V;
21902 
21903   if (SDValue V = reduceBuildVecToShuffle(N))
21904     return V;
21905 
21906   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
21907   // Do this late as some of the above may replace the splat.
21908   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
21909     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
21910       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
21911       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
21912     }
21913 
21914   return SDValue();
21915 }
21916 
21917 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
21918   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21919   EVT OpVT = N->getOperand(0).getValueType();
21920 
21921   // If the operands are legal vectors, leave them alone.
21922   if (TLI.isTypeLegal(OpVT))
21923     return SDValue();
21924 
21925   SDLoc DL(N);
21926   EVT VT = N->getValueType(0);
21927   SmallVector<SDValue, 8> Ops;
21928 
21929   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
21930   SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
21931 
21932   // Keep track of what we encounter.
21933   bool AnyInteger = false;
21934   bool AnyFP = false;
21935   for (const SDValue &Op : N->ops()) {
21936     if (ISD::BITCAST == Op.getOpcode() &&
21937         !Op.getOperand(0).getValueType().isVector())
21938       Ops.push_back(Op.getOperand(0));
21939     else if (ISD::UNDEF == Op.getOpcode())
21940       Ops.push_back(ScalarUndef);
21941     else
21942       return SDValue();
21943 
21944     // Note whether we encounter an integer or floating point scalar.
21945     // If it's neither, bail out, it could be something weird like x86mmx.
21946     EVT LastOpVT = Ops.back().getValueType();
21947     if (LastOpVT.isFloatingPoint())
21948       AnyFP = true;
21949     else if (LastOpVT.isInteger())
21950       AnyInteger = true;
21951     else
21952       return SDValue();
21953   }
21954 
21955   // If any of the operands is a floating point scalar bitcast to a vector,
21956   // use floating point types throughout, and bitcast everything.
21957   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
21958   if (AnyFP) {
21959     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
21960     ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
21961     if (AnyInteger) {
21962       for (SDValue &Op : Ops) {
21963         if (Op.getValueType() == SVT)
21964           continue;
21965         if (Op.isUndef())
21966           Op = ScalarUndef;
21967         else
21968           Op = DAG.getBitcast(SVT, Op);
21969       }
21970     }
21971   }
21972 
21973   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
21974                                VT.getSizeInBits() / SVT.getSizeInBits());
21975   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
21976 }
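// For example, the fold above turns (illustrative, when v2i32 is not a legal
// operand type on the target):
//   t1: v2i32 = bitcast A:i64
//   t2: v4i32 = concat_vectors t1, undef:v2i32
// into
//   t3: v2i64 = BUILD_VECTOR A, undef:i64
//   t4: v4i32 = bitcast t3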
21977 
21978 // Attempt to merge nested concat_vectors/undefs.
21979 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
21980 //  --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
21981 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
21982                                                   SelectionDAG &DAG) {
21983   EVT VT = N->getValueType(0);
21984 
21985   // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
21986   EVT SubVT;
21987   SDValue FirstConcat;
21988   for (const SDValue &Op : N->ops()) {
21989     if (Op.isUndef())
21990       continue;
21991     if (Op.getOpcode() != ISD::CONCAT_VECTORS)
21992       return SDValue();
21993     if (!FirstConcat) {
21994       SubVT = Op.getOperand(0).getValueType();
21995       if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
21996         return SDValue();
21997       FirstConcat = Op;
21998       continue;
21999     }
22000     if (SubVT != Op.getOperand(0).getValueType())
22001       return SDValue();
22002   }
22003   assert(FirstConcat && "Concat of all-undefs found");
22004 
22005   SmallVector<SDValue> ConcatOps;
22006   for (const SDValue &Op : N->ops()) {
22007     if (Op.isUndef()) {
22008       ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
22009       continue;
22010     }
22011     ConcatOps.append(Op->op_begin(), Op->op_end());
22012   }
22013   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
22014 }
22015 
22016 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
22017 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
22018 // most two distinct vectors of the same size as the result, attempt to turn this
22019 // into a legal shuffle.
22020 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
22021   EVT VT = N->getValueType(0);
22022   EVT OpVT = N->getOperand(0).getValueType();
22023 
22024   // We currently can't generate an appropriate shuffle for a scalable vector.
22025   if (VT.isScalableVector())
22026     return SDValue();
22027 
22028   int NumElts = VT.getVectorNumElements();
22029   int NumOpElts = OpVT.getVectorNumElements();
22030 
22031   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
22032   SmallVector<int, 8> Mask;
22033 
22034   for (SDValue Op : N->ops()) {
22035     Op = peekThroughBitcasts(Op);
22036 
22037     // UNDEF nodes convert to UNDEF shuffle mask values.
22038     if (Op.isUndef()) {
22039       Mask.append((unsigned)NumOpElts, -1);
22040       continue;
22041     }
22042 
22043     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
22044       return SDValue();
22045 
22046     // What vector are we extracting the subvector from and at what index?
22047     SDValue ExtVec = Op.getOperand(0);
22048     int ExtIdx = Op.getConstantOperandVal(1);
22049 
22050     // We want the EVT of the original extraction to correctly scale the
22051     // extraction index.
22052     EVT ExtVT = ExtVec.getValueType();
22053     ExtVec = peekThroughBitcasts(ExtVec);
22054 
22055     // UNDEF nodes convert to UNDEF shuffle mask values.
22056     if (ExtVec.isUndef()) {
22057       Mask.append((unsigned)NumOpElts, -1);
22058       continue;
22059     }
22060 
22061     // Ensure that we are extracting a subvector from a vector the same
22062     // size as the result.
22063     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
22064       return SDValue();
22065 
22066     // Scale the subvector index to account for any bitcast.
22067     int NumExtElts = ExtVT.getVectorNumElements();
22068     if (0 == (NumExtElts % NumElts))
22069       ExtIdx /= (NumExtElts / NumElts);
22070     else if (0 == (NumElts % NumExtElts))
22071       ExtIdx *= (NumElts / NumExtElts);
22072     else
22073       return SDValue();
22074 
22075     // At most we can reference 2 inputs in the final shuffle.
22076     if (SV0.isUndef() || SV0 == ExtVec) {
22077       SV0 = ExtVec;
22078       for (int i = 0; i != NumOpElts; ++i)
22079         Mask.push_back(i + ExtIdx);
22080     } else if (SV1.isUndef() || SV1 == ExtVec) {
22081       SV1 = ExtVec;
22082       for (int i = 0; i != NumOpElts; ++i)
22083         Mask.push_back(i + ExtIdx + NumElts);
22084     } else {
22085       return SDValue();
22086     }
22087   }
22088 
22089   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22090   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
22091                                      DAG.getBitcast(VT, SV1), Mask, DAG);
22092 }
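// For example, the fold above turns (illustrative types, assuming the mask is
// legal for the target):
//   t1: v2i32 = extract_subvector X:v4i32, Constant:i64<0>
//   t2: v2i32 = extract_subvector Y:v4i32, Constant:i64<2>
//   t3: v4i32 = concat_vectors t1, t2
// into
//   t4: v4i32 = vector_shuffle<0,1,6,7> X, Y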
22093 
22094 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
22095   unsigned CastOpcode = N->getOperand(0).getOpcode();
22096   switch (CastOpcode) {
22097   case ISD::SINT_TO_FP:
22098   case ISD::UINT_TO_FP:
22099   case ISD::FP_TO_SINT:
22100   case ISD::FP_TO_UINT:
22101     // TODO: Allow more opcodes?
22102     //  case ISD::BITCAST:
22103     //  case ISD::TRUNCATE:
22104     //  case ISD::ZERO_EXTEND:
22105     //  case ISD::SIGN_EXTEND:
22106     //  case ISD::FP_EXTEND:
22107     break;
22108   default:
22109     return SDValue();
22110   }
22111 
22112   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
22113   if (!SrcVT.isVector())
22114     return SDValue();
22115 
22116   // All operands of the concat must be the same kind of cast from the same
22117   // source type.
22118   SmallVector<SDValue, 4> SrcOps;
22119   for (SDValue Op : N->ops()) {
22120     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
22121         Op.getOperand(0).getValueType() != SrcVT)
22122       return SDValue();
22123     SrcOps.push_back(Op.getOperand(0));
22124   }
22125 
22126   // The wider cast must be supported by the target. This is unusual because
22127   // the operation support type parameter depends on the opcode. In addition,
22128   // check the other type in the cast to make sure this is really legal.
22129   EVT VT = N->getValueType(0);
22130   EVT SrcEltVT = SrcVT.getVectorElementType();
22131   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
22132   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
22133   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22134   switch (CastOpcode) {
22135   case ISD::SINT_TO_FP:
22136   case ISD::UINT_TO_FP:
22137     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
22138         !TLI.isTypeLegal(VT))
22139       return SDValue();
22140     break;
22141   case ISD::FP_TO_SINT:
22142   case ISD::FP_TO_UINT:
22143     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
22144         !TLI.isTypeLegal(ConcatSrcVT))
22145       return SDValue();
22146     break;
22147   default:
22148     llvm_unreachable("Unexpected cast opcode");
22149   }
22150 
22151   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
22152   SDLoc DL(N);
22153   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
22154   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
22155 }
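// For example, the fold above turns (illustrative, assuming sint_to_fp on
// v4i32 is legal or custom and v4f32 is legal):
//   t1: v2f32 = sint_to_fp A:v2i32
//   t2: v2f32 = sint_to_fp B:v2i32
//   t3: v4f32 = concat_vectors t1, t2
// into
//   t4: v4i32 = concat_vectors A, B
//   t5: v4f32 = sint_to_fp t4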
22156 
22157 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
22158 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
22159 // to that SHUFFLE_VECTOR, create a wider SHUFFLE_VECTOR.
22160 static SDValue combineConcatVectorOfShuffleAndItsOperands(
22161     SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
22162     bool LegalOperations) {
22163   EVT VT = N->getValueType(0);
22164   EVT OpVT = N->getOperand(0).getValueType();
22165   if (VT.isScalableVector())
22166     return SDValue();
22167 
22168   // For now, only allow simple 2-operand concatenations.
22169   if (N->getNumOperands() != 2)
22170     return SDValue();
22171 
22172   // Don't create illegal types/shuffles when not allowed to.
22173   if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
22174       (LegalOperations &&
22175        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
22176     return SDValue();
22177 
22178   // Analyze all of the operands of the CONCAT_VECTORS. We want to find one
22179   // that is: (1) a SHUFFLE_VECTOR, (2) only used by us, and (3) such that all
22180   // operands of the CONCAT_VECTORS are either that SHUFFLE_VECTOR or one of
22181   // the operands of that SHUFFLE_VECTOR (but not UNDEF!).
22182   // (4) And, for now, the SHUFFLE_VECTOR must be unary.
22183   ShuffleVectorSDNode *SVN = nullptr;
22184   for (SDValue Op : N->ops()) {
22185     if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
22186         CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
22187         all_of(N->ops(), [CurSVN](SDValue Op) {
22188           // FIXME: can we allow UNDEF operands?
22189           return !Op.isUndef() &&
22190                  (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
22191         })) {
22192       SVN = CurSVN;
22193       break;
22194     }
22195   }
22196   if (!SVN)
22197     return SDValue();
22198 
22199   // We are going to pad the shuffle operands, so any index that was picking
22200   // from the second operand must be adjusted.
22201   SmallVector<int, 16> AdjustedMask;
22202   AdjustedMask.reserve(SVN->getMask().size());
22203   assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
22204   append_range(AdjustedMask, SVN->getMask());
22205 
22206   // Identity masks for the operands of the (padded) shuffle.
22207   SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
22208   MutableArrayRef<int> FirstShufOpIdentityMask =
22209       MutableArrayRef<int>(IdentityMask)
22210           .take_front(OpVT.getVectorNumElements());
22211   MutableArrayRef<int> SecondShufOpIdentityMask =
22212       MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
22213   std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
22214   std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
22215             VT.getVectorNumElements());
22216 
22217   // New combined shuffle mask.
22218   SmallVector<int, 32> Mask;
22219   Mask.reserve(VT.getVectorNumElements());
22220   for (SDValue Op : N->ops()) {
22221     assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
22222     if (Op.getNode() == SVN) {
22223       append_range(Mask, AdjustedMask);
22224       continue;
22225     }
22226     if (Op == SVN->getOperand(0)) {
22227       append_range(Mask, FirstShufOpIdentityMask);
22228       continue;
22229     }
22230     if (Op == SVN->getOperand(1)) {
22231       append_range(Mask, SecondShufOpIdentityMask);
22232       continue;
22233     }
22234     llvm_unreachable("Unexpected operand!");
22235   }
22236 
22237   // Don't create illegal shuffle masks.
22238   if (!TLI.isShuffleMaskLegal(Mask, VT))
22239     return SDValue();
22240 
22241   // Pad the shuffle operands with UNDEF.
22242   SDLoc dl(N);
22243   std::array<SDValue, 2> ShufOps;
22244   for (auto I : zip(SVN->ops(), ShufOps)) {
22245     SDValue ShufOp = std::get<0>(I);
22246     SDValue &NewShufOp = std::get<1>(I);
22247     if (ShufOp.isUndef())
22248       NewShufOp = DAG.getUNDEF(VT);
22249     else {
22250       SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
22251                                           DAG.getUNDEF(OpVT));
22252       ShufOpParts[0] = ShufOp;
22253       NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
22254     }
22255   }
22256   // Finally, create the new wide shuffle.
22257   return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
22258 }
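// For example, the fold above turns (illustrative, assuming the wider mask is
// legal):
//   t1: v2i32 = vector_shuffle<1,0> X:v2i32, undef
//   t2: v4i32 = concat_vectors t1, X
// into
//   t3: v4i32 = concat_vectors X, undef:v2i32
//   t4: v4i32 = vector_shuffle<1,0,0,1> t3, undef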
22259 
22260 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
22261   // If we only have one input vector, we don't need to do any concatenation.
22262   if (N->getNumOperands() == 1)
22263     return N->getOperand(0);
22264 
22265   // Check if all of the operands are undefs.
22266   EVT VT = N->getValueType(0);
22267   if (ISD::allOperandsUndef(N))
22268     return DAG.getUNDEF(VT);
22269 
22270   // Optimize concat_vectors where all but the first of the vectors are undef.
22271   if (all_of(drop_begin(N->ops()),
22272              [](const SDValue &Op) { return Op.isUndef(); })) {
22273     SDValue In = N->getOperand(0);
22274     assert(In.getValueType().isVector() && "Must concat vectors");
22275 
22276     // If the input is a concat_vectors, just make a larger concat by padding
22277     // with smaller undefs.
22278     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
22279       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
22280       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
22281       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
22282       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
22283     }
22284 
22285     SDValue Scalar = peekThroughOneUseBitcasts(In);
22286 
22287     // concat_vectors(scalar_to_vector(scalar), undef) ->
22288     //     scalar_to_vector(scalar)
22289     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22290          Scalar.hasOneUse()) {
22291       EVT SVT = Scalar.getValueType().getVectorElementType();
22292       if (SVT == Scalar.getOperand(0).getValueType())
22293         Scalar = Scalar.getOperand(0);
22294     }
22295 
22296     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
22297     if (!Scalar.getValueType().isVector()) {
22298       // If the bitcast type isn't legal, it might be a trunc of a legal type;
22299       // look through the trunc so we can still do the transform:
22300       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
22301       if (Scalar->getOpcode() == ISD::TRUNCATE &&
22302           !TLI.isTypeLegal(Scalar.getValueType()) &&
22303           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
22304         Scalar = Scalar->getOperand(0);
22305 
22306       EVT SclTy = Scalar.getValueType();
22307 
22308       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
22309         return SDValue();
22310 
22311       // Bail out if the vector size is not a multiple of the scalar size.
22312       if (VT.getSizeInBits() % SclTy.getSizeInBits())
22313         return SDValue();
22314 
22315       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
22316       if (VNTNumElms < 2)
22317         return SDValue();
22318 
22319       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
22320       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
22321         return SDValue();
22322 
22323       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
22324       return DAG.getBitcast(VT, Res);
22325     }
22326   }
22327 
22328   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
22329   // We have already tested above for an UNDEF only concatenation.
22330   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
22331   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
22332   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
22333     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
22334   };
22335   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
22336     SmallVector<SDValue, 8> Opnds;
22337     EVT SVT = VT.getScalarType();
22338 
22339     EVT MinVT = SVT;
22340     if (!SVT.isFloatingPoint()) {
22341       // If the BUILD_VECTORs are built from integers, they may have different
22342       // operand types. Get the smallest type and truncate all operands to it.
22343       bool FoundMinVT = false;
22344       for (const SDValue &Op : N->ops())
22345         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
22346           EVT OpSVT = Op.getOperand(0).getValueType();
22347           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
22348           FoundMinVT = true;
22349         }
22350       assert(FoundMinVT && "Concat vector type mismatch");
22351     }
22352 
22353     for (const SDValue &Op : N->ops()) {
22354       EVT OpVT = Op.getValueType();
22355       unsigned NumElts = OpVT.getVectorNumElements();
22356 
22357       if (ISD::UNDEF == Op.getOpcode())
22358         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
22359 
22360       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
22361         if (SVT.isFloatingPoint()) {
22362           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
22363           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
22364         } else {
22365           for (unsigned i = 0; i != NumElts; ++i)
22366             Opnds.push_back(
22367                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
22368         }
22369       }
22370     }
22371 
22372     assert(VT.getVectorNumElements() == Opnds.size() &&
22373            "Concat vector type mismatch");
22374     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
22375   }
22376 
22377   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
22378   // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
22379   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
22380     return V;
22381 
22382   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
22383     // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
22384     if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
22385       return V;
22386 
22387     // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
22388     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
22389       return V;
22390   }
22391 
22392   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
22393     return V;
22394 
22395   if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
22396           N, DAG, TLI, LegalTypes, LegalOperations))
22397     return V;
22398 
22399   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
22400   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
22401 // operands and look for CONCAT operations that place the incoming vectors
22402   // at the exact same location.
22403   //
22404   // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
22405   SDValue SingleSource = SDValue();
22406   unsigned PartNumElem =
22407       N->getOperand(0).getValueType().getVectorMinNumElements();
22408 
22409   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22410     SDValue Op = N->getOperand(i);
22411 
22412     if (Op.isUndef())
22413       continue;
22414 
22415     // Check if this is the identity extract:
22416     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
22417       return SDValue();
22418 
22419     // Find the single incoming vector for the extract_subvector.
22420     if (SingleSource.getNode()) {
22421       if (Op.getOperand(0) != SingleSource)
22422         return SDValue();
22423     } else {
22424       SingleSource = Op.getOperand(0);
22425 
22426       // Check that the source type is the same as the type of the result.
22427       // If not, this concat may extend the vector, so we cannot
22428       // optimize it away.
22429       if (SingleSource.getValueType() != N->getValueType(0))
22430         return SDValue();
22431     }
22432 
22433     // Check that we are reading from the identity index.
22434     unsigned IdentityIndex = i * PartNumElem;
22435     if (Op.getConstantOperandAPInt(1) != IdentityIndex)
22436       return SDValue();
22437   }
22438 
22439   if (SingleSource.getNode())
22440     return SingleSource;
22441 
22442   return SDValue();
22443 }
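// For example, the identity scan at the end of the function above replaces
// (illustrative types):
//   t1: v2i32 = extract_subvector S:v4i32, Constant:i64<0>
//   t2: v2i32 = extract_subvector S, Constant:i64<2>
//   t3: v4i32 = concat_vectors t1, t2
// with S itself.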
22444 
22445 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
22446 // if the subvector can be sourced for free.
22447 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
22448   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
22449       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
22450     return V.getOperand(1);
22451   }
22452   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22453   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
22454       V.getOperand(0).getValueType() == SubVT &&
22455       (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
22456     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
22457     return V.getOperand(SubIdx);
22458   }
22459   return SDValue();
22460 }
22461 
22462 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
22463                                               SelectionDAG &DAG,
22464                                               bool LegalOperations) {
22465   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22466   SDValue BinOp = Extract->getOperand(0);
22467   unsigned BinOpcode = BinOp.getOpcode();
22468   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
22469     return SDValue();
22470 
22471   EVT VecVT = BinOp.getValueType();
22472   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
22473   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
22474     return SDValue();
22475 
22476   SDValue Index = Extract->getOperand(1);
22477   EVT SubVT = Extract->getValueType(0);
22478   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
22479     return SDValue();
22480 
22481   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
22482   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
22483 
22484   // TODO: We could handle the case where only 1 operand is being inserted by
22485   //       creating an extract of the other operand, but that requires checking
22486   //       number of uses and/or costs.
22487   if (!Sub0 || !Sub1)
22488     return SDValue();
22489 
22490   // We are inserting both operands of the wide binop only to extract back
22491   // to the narrow vector size. Eliminate all of the insert/extract:
22492   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
22493   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
22494                      BinOp->getFlags());
22495 }
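// For example, the fold above turns (illustrative types, assuming v4i32 add
// is legal):
//   t1: v8i32 = insert_subvector ?, X:v4i32, Constant:i64<4>
//   t2: v8i32 = insert_subvector ?, Y:v4i32, Constant:i64<4>
//   t3: v8i32 = add t1, t2
//   t4: v4i32 = extract_subvector t3, Constant:i64<4>
// into
//   t5: v4i32 = add X, Y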
22496 
22497 /// If we are extracting a subvector produced by a wide binary operator try
22498 /// to use a narrow binary operator and/or avoid concatenation and extraction.
22499 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
22500                                           bool LegalOperations) {
22501   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
22502   // some of these bailouts with other transforms.
22503 
22504   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
22505     return V;
22506 
22507   // The extract index must be a constant, so we can map it to a concat operand.
22508   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
22509   if (!ExtractIndexC)
22510     return SDValue();
22511 
22512   // We are looking for an optionally bitcasted wide vector binary operator
22513   // feeding an extract subvector.
22514   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22515   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
22516   unsigned BOpcode = BinOp.getOpcode();
22517   if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
22518     return SDValue();
22519 
22520   // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
22521   // reduced to the unary fneg when it is visited, and we probably want to deal
22522   // with fneg in a target-specific way.
22523   if (BOpcode == ISD::FSUB) {
22524     auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
22525     if (C && C->getValueAPF().isNegZero())
22526       return SDValue();
22527   }
22528 
22529   // The binop must be a vector type, so we can extract some fraction of it.
22530   EVT WideBVT = BinOp.getValueType();
22531   // The optimisations below currently assume we are dealing with fixed length
22532   // vectors. It is possible to add support for scalable vectors, but at the
22533   // moment we've done no analysis to prove whether they are profitable or not.
22534   if (!WideBVT.isFixedLengthVector())
22535     return SDValue();
22536 
22537   EVT VT = Extract->getValueType(0);
22538   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
22539   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
22540          "Extract index is not a multiple of the vector length.");
22541 
22542   // Bail out if this is not a proper multiple width extraction.
22543   unsigned WideWidth = WideBVT.getSizeInBits();
22544   unsigned NarrowWidth = VT.getSizeInBits();
22545   if (WideWidth % NarrowWidth != 0)
22546     return SDValue();
22547 
22548   // Bail out if we are extracting a fraction of a single operation. This can
22549   // occur because we potentially looked through a bitcast of the binop.
22550   unsigned NarrowingRatio = WideWidth / NarrowWidth;
22551   unsigned WideNumElts = WideBVT.getVectorNumElements();
22552   if (WideNumElts % NarrowingRatio != 0)
22553     return SDValue();
22554 
22555   // Bail out if the target does not support a narrower version of the binop.
22556   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
22557                                    WideNumElts / NarrowingRatio);
22558   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
22559     return SDValue();
22560 
22561   // If extraction is cheap, we don't need to look at the binop operands
22562   // for concat ops. The narrow binop alone makes this transform profitable.
22563   // We can't just reuse the original extract index operand because we may have
22564   // bitcasted.
22565   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
22566   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
22567   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
22568       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
22569     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
22570     SDLoc DL(Extract);
22571     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
22572     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22573                             BinOp.getOperand(0), NewExtIndex);
22574     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22575                             BinOp.getOperand(1), NewExtIndex);
22576     SDValue NarrowBinOp =
22577         DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
22578     return DAG.getBitcast(VT, NarrowBinOp);
22579   }
22580 
22581   // Only handle the case where we are doubling and then halving. A larger ratio
22582   // may require more than two narrow binops to replace the wide binop.
22583   if (NarrowingRatio != 2)
22584     return SDValue();
22585 
22586   // TODO: The motivating case for this transform is an x86 AVX1 target. That
22587   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
22588   // flavors, but no other 256-bit integer support. This could be extended to
22589   // handle any binop, but that may require fixing/adding other folds to avoid
22590   // codegen regressions.
22591   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
22592     return SDValue();
22593 
22594   // We need at least one concatenation operation of a binop operand to make
22595   // this transform worthwhile. The concat must double the input vector sizes.
22596   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
22597     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
22598       return V.getOperand(ConcatOpNum);
22599     return SDValue();
22600   };
22601   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
22602   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
22603 
22604   if (SubVecL || SubVecR) {
22605     // If a binop operand was not the result of a concat, we must extract a
22606     // half-sized operand for our new narrow binop:
22607     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
22608     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
22609     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
22610     SDLoc DL(Extract);
22611     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
22612     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
22613                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22614                                       BinOp.getOperand(0), IndexC);
22615 
22616     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
22617                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22618                                       BinOp.getOperand(1), IndexC);
22619 
22620     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
22621     return DAG.getBitcast(VT, NarrowBinOp);
22622   }
22623 
22624   return SDValue();
22625 }
22626 
22627 /// If we are extracting a subvector from a wide vector load, convert to a
22628 /// narrow load to eliminate the extraction:
22629 /// (extract_subvector (load wide vector)) --> (load narrow vector)
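/// For example (an illustrative sketch only; assumes little-endian layout and
/// that the target permits the load-width reduction):
///   (v4i32 extract_subvector (v8i32 load %ptr), 4)
///     --> (v4i32 load %ptr+16)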
22630 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
22631   // TODO: Add support for big-endian. The offset calculation must be adjusted.
22632   if (DAG.getDataLayout().isBigEndian())
22633     return SDValue();
22634 
22635   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
22636   if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
22637     return SDValue();
22638 
22639   // Allow targets to opt-out.
22640   EVT VT = Extract->getValueType(0);
22641 
22642   // We can only create byte sized loads.
22643   if (!VT.isByteSized())
22644     return SDValue();
22645 
22646   unsigned Index = Extract->getConstantOperandVal(1);
22647   unsigned NumElts = VT.getVectorMinNumElements();
22648 
22649   // The definition of EXTRACT_SUBVECTOR states that the index must be a
22650   // multiple of the minimum number of elements in the result type.
22651   assert(Index % NumElts == 0 && "The extract subvector index is not a "
22652                                  "multiple of the result's element count");
22653 
22654   // It's fine to use TypeSize here as we know the offset will not be negative.
22655   TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
22656 
22657   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22658   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
22659     return SDValue();
22660 
22661   // The narrow load will be offset from the base address of the old load if
22662   // we are extracting from something besides index 0 (little-endian).
22663   SDLoc DL(Extract);
22664 
22665   // TODO: Use "BaseIndexOffset" to make this more effective.
22666   SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
22667 
22668   uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
22669   MachineFunction &MF = DAG.getMachineFunction();
22670   MachineMemOperand *MMO;
22671   if (Offset.isScalable()) {
22672     MachinePointerInfo MPI =
22673         MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
22674     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
22675   } else
22676     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
22677                                   StoreSize);
22678 
22679   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
22680   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
22681   return NewLd;
22682 }
22683 
22684 /// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
22685 /// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
22686 ///                               EXTRACT_SUBVECTOR(Op?, ?),
22687 ///                               Mask')
22688 /// iff it is legal and profitable to do so. Notably, the trimmed mask
22689 /// (containing only the elements that are extracted)
22690 /// must reference at most two subvectors.
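/// For example (an illustrative sketch; A and B are v8i32 and the extract is
/// v4i32, so only the low halves of A and B are demanded):
///   extract_subvector (shuffle A, B, <0,8,1,9, ...>), 0
///     --> shuffle (extract_subvector A, 0), (extract_subvector B, 0),
///                 <0,4,1,5>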
22691 static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
22692                                                      SelectionDAG &DAG,
22693                                                      const TargetLowering &TLI,
22694                                                      bool LegalOperations) {
22695   assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22696          "Must only be called on EXTRACT_SUBVECTOR's");
22697 
22698   SDValue N0 = N->getOperand(0);
22699 
22700   // Only deal with non-scalable vectors.
22701   EVT NarrowVT = N->getValueType(0);
22702   EVT WideVT = N0.getValueType();
22703   if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
22704     return SDValue();
22705 
22706   // The operand must be a shufflevector.
22707   auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
22708   if (!WideShuffleVector)
22709     return SDValue();
22710 
22711   // The old shuffle needs to go away.
22712   if (!WideShuffleVector->hasOneUse())
22713     return SDValue();
22714 
22715   // And the narrow shufflevector that we'll form must be legal.
22716   if (LegalOperations &&
22717       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
22718     return SDValue();
22719 
22720   uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
22721   int NumEltsExtracted = NarrowVT.getVectorNumElements();
22722   assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
22723          "Extract index is not a multiple of the output vector length.");
22724 
22725   int WideNumElts = WideVT.getVectorNumElements();
22726 
22727   SmallVector<int, 16> NewMask;
22728   NewMask.reserve(NumEltsExtracted);
22729   SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
22730       DemandedSubvectors;
22731 
22732   // Try to decode the wide mask into a narrow mask from at most two subvectors.
22733   for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
22734                                                   NumEltsExtracted)) {
22735     assert((M >= -1) && (M < (2 * WideNumElts)) &&
22736            "Out-of-bounds shuffle mask?");
22737 
22738     if (M < 0) {
22739       // Does not depend on operands, does not require adjustment.
22740       NewMask.emplace_back(M);
22741       continue;
22742     }
22743 
22744     // From which operand of the shuffle does this shuffle mask element pick?
22745     int WideShufOpIdx = M / WideNumElts;
22746     // Which element of that operand is picked?
22747     int OpEltIdx = M % WideNumElts;
22748 
22749     assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
22750            "Shuffle mask vector decomposition failure.");
22751 
22752     // And which NumEltsExtracted-sized subvector of that operand is that?
22753     int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
22754     // And which element within that subvector of that operand is that?
22755     int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
22756 
22757     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
22758            "Shuffle mask subvector decomposition failure.");
22759 
22760     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
22761             WideShufOpIdx * WideNumElts) == M &&
22762            "Shuffle mask full decomposition failure.");
22763 
22764     SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
22765 
22766     if (Op.isUndef()) {
22767       // Picking from an undef operand. Let's adjust mask instead.
22768       NewMask.emplace_back(-1);
22769       continue;
22770     }
22771 
22772     // Profitability check: only deal with extractions from the first subvector.
22773     if (OpSubvecIdx != 0)
22774       return SDValue();
22775 
22776     const std::pair<SDValue, int> DemandedSubvector =
22777         std::make_pair(Op, OpSubvecIdx);
22778 
22779     if (DemandedSubvectors.insert(DemandedSubvector)) {
22780       if (DemandedSubvectors.size() > 2)
22781         return SDValue(); // We can't handle more than two subvectors.
22782       // How many elements into the WideVT does this subvector start?
22783       int Index = NumEltsExtracted * OpSubvecIdx;
22784       // Bail out if the extraction isn't going to be cheap.
22785       if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
22786         return SDValue();
22787     }
22788 
22789     // Ok, but from which operand of the new shuffle will this element pick?
22790     int NewOpIdx =
22791         getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
22792     assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
22793 
22794     int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
22795     NewMask.emplace_back(AdjM);
22796   }
22797   assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
22798   assert(DemandedSubvectors.size() <= 2 &&
22799          "Should have ended up demanding at most two subvectors.");
22800 
22801   // Did we discover that the shuffle does not actually depend on operands?
22802   if (DemandedSubvectors.empty())
22803     return DAG.getUNDEF(NarrowVT);
22804 
22805   // We still perform the exact same EXTRACT_SUBVECTOR, just on different
22806   // operand[s]/index[es], so there is no point in checking its legality.
22807 
22808   // Do not turn a legal shuffle into an illegal one.
22809   if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
22810       !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
22811     return SDValue();
22812 
22813   SDLoc DL(N);
22814 
22815   SmallVector<SDValue, 2> NewOps;
22816   for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
22817            &DemandedSubvector : DemandedSubvectors) {
22818     // How many elements into the WideVT does this subvector start?
22819     int Index = NumEltsExtracted * DemandedSubvector.second;
22820     SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
22821     NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
22822                                     DemandedSubvector.first, IndexC));
22823   }
22824   assert((NewOps.size() == 1 || NewOps.size() == 2) &&
22825          "Should end up with either one or two ops");
22826 
22827   // If we ended up with only one operand, pad with an undef.
22828   if (NewOps.size() == 1)
22829     NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
22830 
22831   return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
22832 }
22833 
22834 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
22835   EVT NVT = N->getValueType(0);
22836   SDValue V = N->getOperand(0);
22837   uint64_t ExtIdx = N->getConstantOperandVal(1);
22838 
22839   // Extract from UNDEF is UNDEF.
22840   if (V.isUndef())
22841     return DAG.getUNDEF(NVT);
22842 
22843   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
22844     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
22845       return NarrowLoad;
22846 
22847   // Combine an extract of an extract into a single extract_subvector.
22848   // ext (ext X, C), 0 --> ext X, C
22849   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
22850     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
22851                                     V.getConstantOperandVal(1)) &&
22852         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
22853       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
22854                          V.getOperand(1));
22855     }
22856   }
22857 
22858   // ty1 extract_vector(ty2 splat(V)) -> ty1 splat(V)
22859   if (V.getOpcode() == ISD::SPLAT_VECTOR)
22860     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
22861       if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
22862         return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
22863 
22864   // Try to move vector bitcast after extract_subv by scaling extraction index:
22865   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
22866   if (V.getOpcode() == ISD::BITCAST &&
22867       V.getOperand(0).getValueType().isVector() &&
22868       (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
22869     SDValue SrcOp = V.getOperand(0);
22870     EVT SrcVT = SrcOp.getValueType();
22871     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
22872     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
22873     if ((SrcNumElts % DestNumElts) == 0) {
22874       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
22875       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
22876       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
22877                                       NewExtEC);
22878       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
22879         SDLoc DL(N);
22880         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
22881         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
22882                                          V.getOperand(0), NewIndex);
22883         return DAG.getBitcast(NVT, NewExtract);
22884       }
22885     }
22886     if ((DestNumElts % SrcNumElts) == 0) {
22887       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
22888       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
22889         ElementCount NewExtEC =
22890             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
22891         EVT ScalarVT = SrcVT.getScalarType();
22892         if ((ExtIdx % DestSrcRatio) == 0) {
22893           SDLoc DL(N);
22894           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
22895           EVT NewExtVT =
22896               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
22897           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
22898             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
22899             SDValue NewExtract =
22900                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
22901                             V.getOperand(0), NewIndex);
22902             return DAG.getBitcast(NVT, NewExtract);
22903           }
22904           if (NewExtEC.isScalar() &&
22905               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
22906             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
22907             SDValue NewExtract =
22908                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
22909                             V.getOperand(0), NewIndex);
22910             return DAG.getBitcast(NVT, NewExtract);
22911           }
22912         }
22913       }
22914     }
22915   }
22916 
22917   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
22918     unsigned ExtNumElts = NVT.getVectorMinNumElements();
22919     EVT ConcatSrcVT = V.getOperand(0).getValueType();
22920     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
22921            "Concat and extract subvector do not change element type");
22922     assert((ExtIdx % ExtNumElts) == 0 &&
22923            "Extract index is not a multiple of the input vector length.");
22924 
22925     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
22926     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
22927 
22928     // If the concatenated source types match this extract, it's a direct
22929     // simplification:
22930     // extract_subvec (concat V1, V2, ...), i --> Vi
22931     if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
22932       return V.getOperand(ConcatOpIdx);
22933 
22934     // If the concatenated source vectors are a multiple length of this extract,
22935     // then extract a fraction of one of those source vectors directly from a
22936     // concat operand. Example:
22937     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
22938     //   v2i8 extract_subvec v8i8 Y, 6
22939     if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
22940         ConcatSrcNumElts % ExtNumElts == 0) {
22941       SDLoc DL(N);
22942       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
22943       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
22944              "Trying to extract from >1 concat operand?");
22945       assert(NewExtIdx % ExtNumElts == 0 &&
22946              "Extract index is not a multiple of the input vector length.");
22947       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
22948       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
22949                          V.getOperand(ConcatOpIdx), NewIndexC);
22950     }
22951   }
22952 
22953   if (SDValue V =
22954           foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
22955     return V;
22956 
22957   V = peekThroughBitcasts(V);
22958 
22959   // If the input is a build vector, try to make a smaller build vector.
22960   if (V.getOpcode() == ISD::BUILD_VECTOR) {
22961     EVT InVT = V.getValueType();
22962     unsigned ExtractSize = NVT.getSizeInBits();
22963     unsigned EltSize = InVT.getScalarSizeInBits();
22964     // Only do this if we won't split any elements.
22965     if (ExtractSize % EltSize == 0) {
22966       unsigned NumElems = ExtractSize / EltSize;
22967       EVT EltVT = InVT.getVectorElementType();
22968       EVT ExtractVT =
22969           NumElems == 1 ? EltVT
22970                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
22971       if ((Level < AfterLegalizeDAG ||
22972            (NumElems == 1 ||
22973             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
22974           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
22975         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
22976 
22977         if (NumElems == 1) {
22978           SDValue Src = V->getOperand(IdxVal);
22979           if (EltVT != Src.getValueType())
22980             Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src);
22981           return DAG.getBitcast(NVT, Src);
22982         }
22983 
22984         // Extract the pieces from the original build_vector.
22985         SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
22986                                               V->ops().slice(IdxVal, NumElems));
22987         return DAG.getBitcast(NVT, BuildVec);
22988       }
22989     }
22990   }
22991 
22992   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
22993     // Handle only simple case where vector being inserted and vector
22994     // being extracted are of same size.
22995     EVT SmallVT = V.getOperand(1).getValueType();
22996     if (!NVT.bitsEq(SmallVT))
22997       return SDValue();
22998 
22999     // Combine:
23000     //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
23001     // Into:
23002     //    indices are equal or bit offsets are equal => V1
23003     //    otherwise => (extract_subvec V1, ExtIdx)
23004     uint64_t InsIdx = V.getConstantOperandVal(2);
23005     if (InsIdx * SmallVT.getScalarSizeInBits() ==
23006         ExtIdx * NVT.getScalarSizeInBits()) {
23007       if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
23008         return SDValue();
23009 
23010       return DAG.getBitcast(NVT, V.getOperand(1));
23011     }
23012     return DAG.getNode(
23013         ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
23014         DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
23015         N->getOperand(1));
23016   }
23017 
23018   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
23019     return NarrowBOp;
23020 
23021   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
23022     return SDValue(N, 0);
23023 
23024   return SDValue();
23025 }
23026 
23027 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
23028 /// followed by concatenation. Narrow vector ops may have better performance
23029 /// than wide ops, and this can unlock further narrowing of other vector ops.
23030 /// Targets can invert this transform later if it is not profitable.
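/// For example (an illustrative sketch; X and Y are v4i16, and the target is
/// assumed to report both half-masks as legal):
///   shuffle (concat X, undef), (concat Y, undef), <0,8,1,9,2,10,3,11>
///     --> concat (shuffle X, Y, <0,4,1,5>), (shuffle X, Y, <2,6,3,7>)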
23031 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
23032                                          SelectionDAG &DAG) {
23033   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
23034   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
23035       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
23036       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
23037     return SDValue();
23038 
23039   // Split the wide shuffle mask into halves. Any mask element that is accessing
23040   // operand 1 is offset down to account for narrowing of the vectors.
23041   ArrayRef<int> Mask = Shuf->getMask();
23042   EVT VT = Shuf->getValueType(0);
23043   unsigned NumElts = VT.getVectorNumElements();
23044   unsigned HalfNumElts = NumElts / 2;
23045   SmallVector<int, 16> Mask0(HalfNumElts, -1);
23046   SmallVector<int, 16> Mask1(HalfNumElts, -1);
23047   for (unsigned i = 0; i != NumElts; ++i) {
23048     if (Mask[i] == -1)
23049       continue;
23050     // If we reference the upper (undef) subvector then the element is undef.
23051     if ((Mask[i] % NumElts) >= HalfNumElts)
23052       continue;
23053     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
23054     if (i < HalfNumElts)
23055       Mask0[i] = M;
23056     else
23057       Mask1[i - HalfNumElts] = M;
23058   }
23059 
23060   // Ask the target if this is a valid transform.
23061   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23062   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
23063                                 HalfNumElts);
23064   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
23065       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
23066     return SDValue();
23067 
23068   // shuffle (concat X, undef), (concat Y, undef), Mask -->
23069   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
23070   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
23071   SDLoc DL(Shuf);
23072   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
23073   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
23074   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
23075 }
23076 
23077 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
23078 // or turn a shuffle of a single concat into a simpler shuffle then a concat.
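// For example (an illustrative sketch; A, B, C, and D are v4i32):
//   shuffle (concat A, B), (concat C, D), <4,5,6,7,8,9,10,11>
//     --> concat B, C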
23079 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
23080   EVT VT = N->getValueType(0);
23081   unsigned NumElts = VT.getVectorNumElements();
23082 
23083   SDValue N0 = N->getOperand(0);
23084   SDValue N1 = N->getOperand(1);
23085   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
23086   ArrayRef<int> Mask = SVN->getMask();
23087 
23088   SmallVector<SDValue, 4> Ops;
23089   EVT ConcatVT = N0.getOperand(0).getValueType();
23090   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
23091   unsigned NumConcats = NumElts / NumElemsPerConcat;
23092 
23093   auto IsUndefMaskElt = [](int i) { return i == -1; };
23094 
23095   // Special case: shuffle(concat(A,B)) can be more efficiently represented
23096   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
23097   // half vector elements.
23098   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
23099       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
23100                    IsUndefMaskElt)) {
23101     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
23102                               N0.getOperand(1),
23103                               Mask.slice(0, NumElemsPerConcat));
23104     N1 = DAG.getUNDEF(ConcatVT);
23105     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
23106   }
23107 
23108   // Look at every vector that's inserted. We're looking for exact
23109   // subvector-sized copies from a concatenated vector
23110   for (unsigned I = 0; I != NumConcats; ++I) {
23111     unsigned Begin = I * NumElemsPerConcat;
23112     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
23113 
23114     // Make sure we're dealing with a copy.
23115     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
23116       Ops.push_back(DAG.getUNDEF(ConcatVT));
23117       continue;
23118     }
23119 
23120     int OpIdx = -1;
23121     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
23122       if (IsUndefMaskElt(SubMask[i]))
23123         continue;
23124       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
23125         return SDValue();
23126       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
23127       if (0 <= OpIdx && EltOpIdx != OpIdx)
23128         return SDValue();
23129       OpIdx = EltOpIdx;
23130     }
23131     assert(0 <= OpIdx && "Unknown concat_vectors op");
23132 
23133     if (OpIdx < (int)N0.getNumOperands())
23134       Ops.push_back(N0.getOperand(OpIdx));
23135     else
23136       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
23137   }
23138 
23139   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
23140 }
23141 
23142 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
23143 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
23144 //
23145 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
23146 // a simplification in some sense, but it isn't appropriate in general: some
23147 // BUILD_VECTORs are substantially cheaper than others. The general case
23148 // of a BUILD_VECTOR requires inserting each element individually (or
23149 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
23150 // all constants is a single constant pool load.  A BUILD_VECTOR where each
23151 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
23152 // are undef lowers to a small number of element insertions.
23153 //
23154 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
23155 // We don't fold shuffles where one side is a non-zero constant, and we don't
23156 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
23157 // non-constant operands. This seems to work out reasonably well in practice.
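// For example (an illustrative sketch using all-constant inputs, which are
// always safe to rebuild):
//   shuffle (build_vector C0, C1, C2, C3), undef, <3,2,1,0>
//     --> build_vector C3, C2, C1, C0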
23158 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
23159                                        SelectionDAG &DAG,
23160                                        const TargetLowering &TLI) {
23161   EVT VT = SVN->getValueType(0);
23162   unsigned NumElts = VT.getVectorNumElements();
23163   SDValue N0 = SVN->getOperand(0);
23164   SDValue N1 = SVN->getOperand(1);
23165 
23166   if (!N0->hasOneUse())
23167     return SDValue();
23168 
23169   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
23170   // discussed above.
23171   if (!N1.isUndef()) {
23172     if (!N1->hasOneUse())
23173       return SDValue();
23174 
23175     bool N0AnyConst = isAnyConstantBuildVector(N0);
23176     bool N1AnyConst = isAnyConstantBuildVector(N1);
23177     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
23178       return SDValue();
23179     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
23180       return SDValue();
23181   }
23182 
23183   // If both inputs are splats of the same value then we can safely merge this
23184   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
23185   bool IsSplat = false;
23186   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
23187   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
23188   if (BV0 && BV1)
23189     if (SDValue Splat0 = BV0->getSplatValue())
23190       IsSplat = (Splat0 == BV1->getSplatValue());
23191 
23192   SmallVector<SDValue, 8> Ops;
23193   SmallSet<SDValue, 16> DuplicateOps;
23194   for (int M : SVN->getMask()) {
23195     SDValue Op = DAG.getUNDEF(VT.getScalarType());
23196     if (M >= 0) {
23197       int Idx = M < (int)NumElts ? M : M - NumElts;
23198       SDValue &S = (M < (int)NumElts ? N0 : N1);
23199       if (S.getOpcode() == ISD::BUILD_VECTOR) {
23200         Op = S.getOperand(Idx);
23201       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
23202         SDValue Op0 = S.getOperand(0);
23203         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
23204       } else {
23205         // Operand can't be combined - bail out.
23206         return SDValue();
23207       }
23208     }
23209 
23210     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
23211     // generating a splat; semantically, this is fine, but it's likely to
23212     // generate low-quality code if the target can't reconstruct an appropriate
23213     // shuffle.
23214     if (!Op.isUndef() && !isIntOrFPConstant(Op))
23215       if (!IsSplat && !DuplicateOps.insert(Op).second)
23216         return SDValue();
23217 
23218     Ops.push_back(Op);
23219   }
23220 
23221   // BUILD_VECTOR requires all inputs to be of the same type; find the
23222   // maximum type and extend them all.
23223   EVT SVT = VT.getScalarType();
23224   if (SVT.isInteger())
23225     for (SDValue &Op : Ops)
23226       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
23227   if (SVT != VT.getScalarType())
23228     for (SDValue &Op : Ops)
23229       Op = Op.isUndef() ? DAG.getUNDEF(SVT)
23230                         : (TLI.isZExtFree(Op.getValueType(), SVT)
23231                                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
23232                                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
23233   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
23234 }
23235 
23236 // Match shuffles that can be converted to *_extend_vector_inreg.
23237 // This is often generated during legalization.
23238 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_extend_vector_inreg(v4i32 src)),
23239 // and returns the EVT to which the extension should be performed.
23240 // NOTE: this assumes that the src is the first operand of the shuffle.
23241 static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
23242     unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
23243     SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
23244     bool LegalOperations) {
23245   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23246 
23247   // TODO Add support for big-endian when we have a test case.
23248   if (!VT.isInteger() || IsBigEndian)
23249     return std::nullopt;
23250 
23251   unsigned NumElts = VT.getVectorNumElements();
23252   unsigned EltSizeInBits = VT.getScalarSizeInBits();
23253 
23254   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
23255   // power-of-2 extensions as they are the most likely.
23256   // FIXME: should try Scale == NumElts case too,
23257   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
23258     // The vector width must be a multiple of Scale.
23259     if (NumElts % Scale != 0)
23260       continue;
23261 
23262     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
23263     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
23264 
23265     if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
23266         (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
23267       continue;
23268 
23269     if (Match(Scale))
23270       return OutVT;
23271   }
23272 
23273   return std::nullopt;
23274 }
23275 
23276 // Match shuffles that can be converted to any_extend_vector_inreg.
23277 // This is often generated during legalization.
23278 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_extend_vector_inreg(v4i32 src))
23279 static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
23280                                                     SelectionDAG &DAG,
23281                                                     const TargetLowering &TLI,
23282                                                     bool LegalOperations) {
23283   EVT VT = SVN->getValueType(0);
23284   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23285 
23286   // TODO Add support for big-endian when we have a test case.
23287   if (!VT.isInteger() || IsBigEndian)
23288     return SDValue();
23289 
23290   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
23291   auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
23292                       Mask = SVN->getMask()](unsigned Scale) {
23293     for (unsigned i = 0; i != NumElts; ++i) {
23294       if (Mask[i] < 0)
23295         continue;
23296       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
23297         continue;
23298       return false;
23299     }
23300     return true;
23301   };
23302 
23303   unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
23304   SDValue N0 = SVN->getOperand(0);
23305   // Never create an illegal type. Only create unsupported operations if we
23306   // are pre-legalization.
23307   std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
23308       Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
23309   if (!OutVT)
23310     return SDValue();
23311   return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
23312 }
23313 
23314 // Match shuffles that can be converted to zero_extend_vector_inreg.
23315 // This is often generated during legalization.
23316 // e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
23317 static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
23318                                                      SelectionDAG &DAG,
23319                                                      const TargetLowering &TLI,
23320                                                      bool LegalOperations) {
23321   bool LegalTypes = true;
23322   EVT VT = SVN->getValueType(0);
23323   assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
23324   unsigned NumElts = VT.getVectorNumElements();
23325   unsigned EltSizeInBits = VT.getScalarSizeInBits();
23326 
23327   // TODO: add support for big-endian when we have a test case.
23328   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23329   if (!VT.isInteger() || IsBigEndian)
23330     return SDValue();
23331 
23332   SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
23333   auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
23334     for (int &Indice : Mask) {
23335       if (Indice < 0)
23336         continue;
23337       int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
23338       int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
23339       Fn(Indice, OpIdx, OpEltIdx);
23340     }
23341   };
23342 
23343   // Which elements of which operand does this shuffle demand?
23344   std::array<APInt, 2> OpsDemandedElts;
23345   for (APInt &OpDemandedElts : OpsDemandedElts)
23346     OpDemandedElts = APInt::getZero(NumElts);
23347   ForEachDecomposedIndice(
23348       [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
23349         OpsDemandedElts[OpIdx].setBit(OpEltIdx);
23350       });
23351 
23352   // Element-wise(!), which of these demanded elements are known to be zero?
23353   std::array<APInt, 2> OpsKnownZeroElts;
23354   for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
23355     std::get<2>(I) =
23356         DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));
23357 
23358   // Manifest zeroable element knowledge in the shuffle mask.
23359   // NOTE: we don't have 'zeroable' sentinel value in generic DAG,
23360   //       this is a local invention, but it won't leak into DAG.
23361   // FIXME: should we not manifest them, but just check when matching?
23362   bool HadZeroableElts = false;
23363   ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
23364                               int &Indice, int OpIdx, int OpEltIdx) {
23365     if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
23366       Indice = -2; // Zeroable element.
23367       HadZeroableElts = true;
23368     }
23369   });
23370 
23371   // Don't proceed unless we've refined at least one zeroable mask index.
23372   // If we didn't, then we are still trying to match the same shuffle mask
23373   // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
23374   // and evidently failed. Proceeding will lead to endless combine loops.
23375   if (!HadZeroableElts)
23376     return SDValue();
23377 
23378   // The shuffle may be more fine-grained than we want. Widen elements first.
23379   // FIXME: should we do this before manifesting zeroable shuffle mask indices?
23380   SmallVector<int, 16> ScaledMask;
23381   getShuffleMaskWithWidestElts(Mask, ScaledMask);
23382   assert(Mask.size() >= ScaledMask.size() &&
23383          Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
23384   int Prescale = Mask.size() / ScaledMask.size();
23385 
23386   NumElts = ScaledMask.size();
23387   EltSizeInBits *= Prescale;
23388 
23389   EVT PrescaledVT = EVT::getVectorVT(
23390       *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
23391       NumElts);
23392 
23393   if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
23394     return SDValue();
23395 
23396   // For example,
23397   // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
23398   // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
23399   auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
23400     assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
23401            "Unexpected mask scaling factor.");
23402     ArrayRef<int> Mask = ScaledMask;
23403     for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
23404          SrcElt != NumSrcElts; ++SrcElt) {
23405       // Analyze the shuffle mask in Scale-sized chunks.
23406       ArrayRef<int> MaskChunk = Mask.take_front(Scale);
23407       assert(MaskChunk.size() == Scale && "Unexpected mask size.");
23408       Mask = Mask.drop_front(MaskChunk.size());
23409       // The first index in this chunk must be SrcElt, but not zero!
23410       // FIXME: undef should be fine, but that results in more-defined result.
23411       if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
23412         return false;
23413       // The rest of the indices in this chunk must be zeros.
23414       // FIXME: undef should be fine, but that results in more-defined result.
23415       if (!all_of(MaskChunk.drop_front(1),
23416                   [](int Indice) { return Indice == -2; }))
23417         return false;
23418     }
23419     assert(Mask.empty() && "Did not process the whole mask?");
23420     return true;
23421   };
23422 
23423   unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
23424   for (bool Commuted : {false, true}) {
23425     SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
23426     if (Commuted)
23427       ShuffleVectorSDNode::commuteMask(ScaledMask);
23428     std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
23429         Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
23430         LegalOperations);
23431     if (OutVT)
23432       return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
23433                                             DAG.getBitcast(PrescaledVT, Op)));
23434   }
23435   return SDValue();
23436 }
23437 
23438 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
23439 // each source element of a large type into the lowest elements of a smaller
23440 // destination type. This is often generated during legalization.
23441 // If the source node itself was a '*_extend_vector_inreg' node then we
23442 // should be able to remove it.
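// For example (an illustrative sketch):
//   shuffle (bitcast (v2i64 any_extend_vector_inreg (v4i32 X))), undef,
//           <0,2,-1,-1>
//     --> X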
23443 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
23444                                         SelectionDAG &DAG) {
23445   EVT VT = SVN->getValueType(0);
23446   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23447 
23448   // TODO Add support for big-endian when we have a test case.
23449   if (!VT.isInteger() || IsBigEndian)
23450     return SDValue();
23451 
23452   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
23453 
23454   unsigned Opcode = N0.getOpcode();
23455   if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
23456       Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
23457       Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
23458     return SDValue();
23459 
23460   SDValue N00 = N0.getOperand(0);
23461   ArrayRef<int> Mask = SVN->getMask();
23462   unsigned NumElts = VT.getVectorNumElements();
23463   unsigned EltSizeInBits = VT.getScalarSizeInBits();
23464   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
23465   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
23466 
23467   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
23468     return SDValue();
23469   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
23470 
23471   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
23472   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
23473   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
23474   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
23475     for (unsigned i = 0; i != NumElts; ++i) {
23476       if (Mask[i] < 0)
23477         continue;
23478       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
23479         continue;
23480       return false;
23481     }
23482     return true;
23483   };
23484 
23485   // At the moment we just handle the case where we've truncated back to the
23486   // same size as before the extension.
23487   // TODO: handle more extension/truncation cases as cases arise.
23488   if (EltSizeInBits != ExtSrcSizeInBits)
23489     return SDValue();
23490 
23491   // We can remove *extend_vector_inreg only if the truncation happens at
23492   // the same scale as the extension.
23493   if (isTruncate(ExtScale))
23494     return DAG.getBitcast(VT, N00);
23495 
23496   return SDValue();
23497 }
23498 
23499 // Combine shuffles of splat-shuffles of the form:
23500 // shuffle (shuffle V, undef, splat-mask), undef, M
23501 // If splat-mask contains undef elements, we need to be careful about
23502 // introducing undefs in the folded mask which are not the result of composing
23503 // the masks of the shuffles.
23504 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
23505                                         SelectionDAG &DAG) {
23506   EVT VT = Shuf->getValueType(0);
23507   unsigned NumElts = VT.getVectorNumElements();
23508 
23509   if (!Shuf->getOperand(1).isUndef())
23510     return SDValue();
23511 
23512   // See if this unary non-splat shuffle actually *is* a splat shuffle,
23513   // in disguise, with all demanded elements being identical.
23514   // FIXME: this can be done per-operand.
23515   if (!Shuf->isSplat()) {
23516     APInt DemandedElts(NumElts, 0);
23517     for (int Idx : Shuf->getMask()) {
23518       if (Idx < 0)
23519         continue; // Ignore sentinel indices.
23520       assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle index?");
23521       DemandedElts.setBit(Idx);
23522     }
23523     assert(DemandedElts.countPopulation() > 1 && "Is a splat shuffle already?");
23524     APInt UndefElts;
23525     if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
23526       // Even if all demanded elements are splat, some of them could be undef.
23527       // Which lowest demanded element is *not* known-undef?
23528       std::optional<unsigned> MinNonUndefIdx;
23529       for (int Idx : Shuf->getMask()) {
23530         if (Idx < 0 || UndefElts[Idx])
23531           continue; // Ignore sentinel indices, and undef elements.
23532         MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
23533       }
23534       if (!MinNonUndefIdx)
23535         return DAG.getUNDEF(VT); // All undef - result is undef.
23536       assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
23537       SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
23538                                     Shuf->getMask().end());
23539       for (int &Idx : SplatMask) {
23540         if (Idx < 0)
23541           continue; // Passthrough sentinel indices.
23542         // Otherwise, just pick the lowest demanded non-undef element.
23543         // Or sentinel undef, if we know we'd pick a known-undef element.
23544         Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
23545       }
23546       assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
23547       return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
23548                                   Shuf->getOperand(1), SplatMask);
23549     }
23550   }
23551 
23552   // If the inner operand is a known splat with no undefs, just return that directly.
23553   // TODO: Create DemandedElts mask from Shuf's mask.
23554   // TODO: Allow undef elements and merge with the shuffle code below.
23555   if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
23556     return Shuf->getOperand(0);
23557 
23558   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
23559   if (!Splat || !Splat->isSplat())
23560     return SDValue();
23561 
23562   ArrayRef<int> ShufMask = Shuf->getMask();
23563   ArrayRef<int> SplatMask = Splat->getMask();
23564   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
23565 
23566   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
23567   // every undef mask element in the splat-shuffle has a corresponding undef
23568   // element in the user-shuffle's mask or if the composition of mask elements
23569   // would result in undef.
23570   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
23571   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
23572   //   In this case it is not legal to simplify to the splat-shuffle because we
23573 //   may be exposing to the users of the shuffle an undef element at index 1
23574   //   which was not there before the combine.
23575   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
23576   //   In this case the composition of masks yields SplatMask, so it's ok to
23577   //   simplify to the splat-shuffle.
23578   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
23579   //   In this case the composed mask includes all undef elements of SplatMask
23580   //   and in addition sets element zero to undef. It is safe to simplify to
23581   //   the splat-shuffle.
23582   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
23583                                        ArrayRef<int> SplatMask) {
23584     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
23585       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
23586           SplatMask[UserMask[i]] != -1)
23587         return false;
23588     return true;
23589   };
23590   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
23591     return Shuf->getOperand(0);
23592 
23593   // Create a new shuffle with a mask that is composed of the two shuffles'
23594   // masks.
23595   SmallVector<int, 32> NewMask;
23596   for (int Idx : ShufMask)
23597     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
23598 
23599   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
23600                               Splat->getOperand(0), Splat->getOperand(1),
23601                               NewMask);
23602 }
23603 
23604 // Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
23605 // the mask can be treated as a larger type.
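// For example (an illustrative sketch): a v8i16 shuffle where each pair of
// i16 lanes moves together can be performed on the wider v4i32 lanes instead:
//   shuffle (bitcast v4i32 X), (bitcast v4i32 Y), <2,3,8,9,6,7,12,13>
//     --> bitcast (shuffle X, Y, <1,4,3,6>)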
23606 static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
23607                                        SelectionDAG &DAG,
23608                                        const TargetLowering &TLI,
23609                                        bool LegalOperations) {
23610   SDValue Op0 = SVN->getOperand(0);
23611   SDValue Op1 = SVN->getOperand(1);
23612   EVT VT = SVN->getValueType(0);
23613   if (Op0.getOpcode() != ISD::BITCAST)
23614     return SDValue();
23615   EVT InVT = Op0.getOperand(0).getValueType();
23616   if (!InVT.isVector() ||
23617       (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
23618                           Op1.getOperand(0).getValueType() != InVT)))
23619     return SDValue();
23620   if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
23621       (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
23622     return SDValue();
23623 
23624   int VTLanes = VT.getVectorNumElements();
23625   int InLanes = InVT.getVectorNumElements();
23626   if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
23627       (LegalOperations &&
23628        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
23629     return SDValue();
23630   int Factor = VTLanes / InLanes;
23631 
23632   // Check that each group of lanes in the mask is either undef or makes a
23633   // valid mask for the wider lane type.
23634   ArrayRef<int> Mask = SVN->getMask();
23635   SmallVector<int> NewMask;
23636   if (!widenShuffleMaskElts(Factor, Mask, NewMask))
23637     return SDValue();
23638 
23639   if (!TLI.isShuffleMaskLegal(NewMask, InVT))
23640     return SDValue();
23641 
23642   // Create the new shuffle with the new mask and bitcast it back to the
23643   // original type.
23644   SDLoc DL(SVN);
23645   Op0 = Op0.getOperand(0);
23646   Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
23647   SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
23648   return DAG.getBitcast(VT, NewShuf);
23649 }
23650 
23651 /// Combine shuffle of shuffle of the form:
23652 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
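/// For example (an illustrative sketch): OuterMask <1,1,3,3> composed with
/// InnerMask <0,2,0,2> reads inner element 2 in every lane, so the pair folds
/// to a single splat shuffle of X with mask <2,2,2,2>.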
23653 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
23654                                      SelectionDAG &DAG) {
23655   if (!OuterShuf->getOperand(1).isUndef())
23656     return SDValue();
23657   auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
23658   if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
23659     return SDValue();
23660 
23661   ArrayRef<int> OuterMask = OuterShuf->getMask();
23662   ArrayRef<int> InnerMask = InnerShuf->getMask();
23663   unsigned NumElts = OuterMask.size();
23664   assert(NumElts == InnerMask.size() && "Mask length mismatch");
23665   SmallVector<int, 32> CombinedMask(NumElts, -1);
23666   int SplatIndex = -1;
23667   for (unsigned i = 0; i != NumElts; ++i) {
23668     // Undef lanes remain undef.
23669     int OuterMaskElt = OuterMask[i];
23670     if (OuterMaskElt == -1)
23671       continue;
23672 
23673     // Peek through the shuffle masks to get the underlying source element.
23674     int InnerMaskElt = InnerMask[OuterMaskElt];
23675     if (InnerMaskElt == -1)
23676       continue;
23677 
23678     // Initialize the splatted element.
23679     if (SplatIndex == -1)
23680       SplatIndex = InnerMaskElt;
23681 
23682     // Non-matching index - this is not a splat.
23683     if (SplatIndex != InnerMaskElt)
23684       return SDValue();
23685 
23686     CombinedMask[i] = InnerMaskElt;
23687   }
23688   assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
23689           getSplatIndex(CombinedMask) != -1) &&
23690          "Expected a splat mask");
23691 
23692   // TODO: The transform may be a win even if the mask is not legal.
23693   EVT VT = OuterShuf->getValueType(0);
23694   assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
23695   if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
23696     return SDValue();
23697 
23698   return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
23699                               InnerShuf->getOperand(1), CombinedMask);
23700 }
23701 
23702 /// If the shuffle mask is taking exactly one element from the first vector
23703 /// operand and passing through all other elements from the second vector
23704 /// operand, return the index of the mask element that is choosing an element
23705 /// from the first operand. Otherwise, return -1.
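/// For example (an illustrative sketch, with a mask of size 4):
///   <4,5,2,7> returns 2: lane 2 takes element 2 of operand 0, while lanes 0,
///   1, and 3 pass through the corresponding elements of operand 1.
///   <4,1,2,7> returns -1, because two lanes would read from operand 0.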
23706 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
23707   int MaskSize = Mask.size();
23708   int EltFromOp0 = -1;
23709   // TODO: This does not match if there are undef elements in the shuffle mask.
23710   // Should we ignore undefs in the shuffle mask instead? The trade-off is
23711   // removing an instruction (a shuffle), but losing the knowledge that some
23712   // vector lanes are not needed.
23713   for (int i = 0; i != MaskSize; ++i) {
23714     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
23715       // We're looking for a shuffle of exactly one element from operand 0.
23716       if (EltFromOp0 != -1)
23717         return -1;
23718       EltFromOp0 = i;
23719     } else if (Mask[i] != i + MaskSize) {
23720       // Nothing from operand 1 can change lanes.
23721       return -1;
23722     }
23723   }
23724   return EltFromOp0;
23725 }
23726 
23727 /// If a shuffle inserts exactly one element from a source vector operand into
23728 /// another vector operand and we can access the specified element as a scalar,
23729 /// then we can eliminate the shuffle.
23730 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
23731                                       SelectionDAG &DAG) {
23732   // First, check if we are taking one element of a vector and shuffling that
23733   // element into another vector.
23734   ArrayRef<int> Mask = Shuf->getMask();
23735   SmallVector<int, 16> CommutedMask(Mask);
23736   SDValue Op0 = Shuf->getOperand(0);
23737   SDValue Op1 = Shuf->getOperand(1);
23738   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
23739   if (ShufOp0Index == -1) {
23740     // Commute mask and check again.
23741     ShuffleVectorSDNode::commuteMask(CommutedMask);
23742     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
23743     if (ShufOp0Index == -1)
23744       return SDValue();
23745     // Commute operands to match the commuted shuffle mask.
23746     std::swap(Op0, Op1);
23747     Mask = CommutedMask;
23748   }
23749 
23750   // The shuffle inserts exactly one element from operand 0 into operand 1.
23751   // Now see if we can access that element as a scalar via a real insert element
23752   // instruction.
23753   // TODO: We can try harder to locate the element as a scalar. Examples: it
23754   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
23755   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
23756          "Shuffle mask value must be from operand 0");
23757   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
23758     return SDValue();
23759 
23760   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
23761   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
23762     return SDValue();
23763 
23764   // There's an existing insertelement with constant insertion index, so we
23765   // don't need to check the legality/profitability of a replacement operation
23766   // that differs at most in the constant value. The target should be able to
23767   // lower any of those in a similar way. If not, legalization will expand this
23768   // to a scalar-to-vector plus shuffle.
23769   //
23770   // Note that the shuffle may move the scalar from the position that the insert
23771   // element used. Therefore, our new insert element occurs at the shuffle's
23772   // mask index value, not the insert's index value.
23773   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
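  // e.g. (illustrative) shuffle (insertelt v1, x, 0), v2, <4,5,0,7>
  //      --> insertelt v2, x, 2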
23774   SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
23775   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
23776                      Op1, Op0.getOperand(1), NewInsIndex);
23777 }
23778 
23779 /// If we have a unary shuffle of a shuffle, see if it can be folded away
23780 /// completely. This has the potential to lose undef knowledge because the first
23781 /// shuffle may not have an undef mask element where the second one does. So
23782 /// only call this after doing simplifications based on demanded elements.
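/// e.g. (illustrative) shuf (shuf0 X, Y, <0,0,1,1>), undef, <1,0,3,2> folds
/// to the inner shuffle: swapping lanes within each pair still selects the
/// same source elements.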
23783 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
23784   // shuf (shuf0 X, Y, Mask0), undef, Mask
23785   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
23786   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
23787     return SDValue();
23788 
23789   ArrayRef<int> Mask = Shuf->getMask();
23790   ArrayRef<int> Mask0 = Shuf0->getMask();
23791   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
23792     // Ignore undef elements.
23793     if (Mask[i] == -1)
23794       continue;
23795     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
23796 
23797     // Is the element of the shuffle operand chosen by this shuffle the same as
23798     // the element chosen by the shuffle operand itself?
23799     if (Mask0[Mask[i]] != Mask0[i])
23800       return SDValue();
23801   }
23802   // Every element of this shuffle is identical to the result of the previous
23803   // shuffle, so we can replace this value.
23804   return Shuf->getOperand(0);
23805 }
23806 
23807 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
23808   EVT VT = N->getValueType(0);
23809   unsigned NumElts = VT.getVectorNumElements();
23810 
23811   SDValue N0 = N->getOperand(0);
23812   SDValue N1 = N->getOperand(1);
23813 
23814   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
23815 
23816   // Canonicalize shuffle undef, undef -> undef
23817   if (N0.isUndef() && N1.isUndef())
23818     return DAG.getUNDEF(VT);
23819 
23820   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
23821 
23822   // Canonicalize shuffle v, v -> v, undef
23823   if (N0 == N1)
23824     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
23825                                 createUnaryMask(SVN->getMask(), NumElts));
23826 
23827   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
23828   if (N0.isUndef())
23829     return DAG.getCommutedVectorShuffle(*SVN);
23830 
23831   // Remove references to rhs if it is undef
23832   if (N1.isUndef()) {
23833     bool Changed = false;
23834     SmallVector<int, 8> NewMask;
23835     for (unsigned i = 0; i != NumElts; ++i) {
23836       int Idx = SVN->getMaskElt(i);
23837       if (Idx >= (int)NumElts) {
23838         Idx = -1;
23839         Changed = true;
23840       }
23841       NewMask.push_back(Idx);
23842     }
23843     if (Changed)
23844       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
23845   }
23846 
23847   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
23848     return InsElt;
23849 
23850   // A shuffle of a single vector that is a splatted value can always be folded.
23851   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
23852     return V;
23853 
23854   if (SDValue V = formSplatFromShuffles(SVN, DAG))
23855     return V;
23856 
23857   // If it is a splat, check if the argument vector is another splat or a
23858   // build_vector.
23859   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
23860     int SplatIndex = SVN->getSplatIndex();
23861     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
23862         TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
23863       // splat (vector_bo L, R), Index -->
23864       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
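      // e.g. (illustrative, v4i32) splat (add L, R), 2 -->
      //   shuffle (scalar_to_vector (add (extelt L, 2), (extelt R, 2))),
      //           undef, <0,0,0,0>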
23865       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
23866       SDLoc DL(N);
23867       EVT EltVT = VT.getScalarType();
23868       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
23869       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
23870       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
23871       SDValue NewBO =
23872           DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
23873       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
23874       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
23875       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
23876     }
23877 
23878     // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
23879     // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
23880     if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
23881         N0.hasOneUse()) {
23882       if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
23883         return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
23884 
23885       if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
23886         if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
23887           if (Idx->getAPIntValue() == SplatIndex)
23888             return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
23889 
23890       // If little-endian and splatting lane 0, look through a bitcast,
23891       // through to a scalar_to_vector or a build_vector.
23892       if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
23893           SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
23894           (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
23895            N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
23896         EVT N00VT = N0.getOperand(0).getValueType();
23897         if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
23898             VT.isInteger() && N00VT.isInteger()) {
23899           EVT InVT =
23900               TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
23901           SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
23902                                           SDLoc(N), InVT);
23903           return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
23904         }
23905       }
23906     }
23907 
23908     // If this is a bit convert that changes the element type of the vector but
23909     // not the number of vector elements, look through it.  Be careful not to
23910   // look through conversions that change things like v4f32 to v2f64.
23911     SDNode *V = N0.getNode();
23912     if (V->getOpcode() == ISD::BITCAST) {
23913       SDValue ConvInput = V->getOperand(0);
23914       if (ConvInput.getValueType().isVector() &&
23915           ConvInput.getValueType().getVectorNumElements() == NumElts)
23916         V = ConvInput.getNode();
23917     }
23918 
23919     if (V->getOpcode() == ISD::BUILD_VECTOR) {
23920       assert(V->getNumOperands() == NumElts &&
23921              "BUILD_VECTOR has wrong number of operands");
23922       SDValue Base;
23923       bool AllSame = true;
23924       for (unsigned i = 0; i != NumElts; ++i) {
23925         if (!V->getOperand(i).isUndef()) {
23926           Base = V->getOperand(i);
23927           break;
23928         }
23929       }
23930       // Splat of <u, u, u, u>, return <u, u, u, u>
23931       if (!Base.getNode())
23932         return N0;
23933       for (unsigned i = 0; i != NumElts; ++i) {
23934         if (V->getOperand(i) != Base) {
23935           AllSame = false;
23936           break;
23937         }
23938       }
23939       // Splat of <x, x, x, x>, return <x, x, x, x>
23940       if (AllSame)
23941         return N0;
23942 
23943       // Canonicalize any other splat as a build_vector.
23944       SDValue Splatted = V->getOperand(SplatIndex);
23945       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
23946       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
23947 
23948       // We may have jumped through bitcasts, so the type of the
23949       // BUILD_VECTOR may not match the type of the shuffle.
23950       if (V->getValueType(0) != VT)
23951         NewBV = DAG.getBitcast(VT, NewBV);
23952       return NewBV;
23953     }
23954   }
23955 
23956   // Simplify source operands based on shuffle mask.
23957   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
23958     return SDValue(N, 0);
23959 
23960   // This is intentionally placed after demanded elements simplification because
23961   // it could eliminate knowledge of undef elements created by this shuffle.
23962   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
23963     return ShufOp;
23964 
23965   // Match shuffles that can be converted to any_vector_extend_in_reg.
23966   if (SDValue V =
23967           combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
23968     return V;
23969 
23970   // Combine "truncate_vector_in_reg" style shuffles.
23971   if (SDValue V = combineTruncationShuffle(SVN, DAG))
23972     return V;
23973 
23974   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
23975       Level < AfterLegalizeVectorOps &&
23976       (N1.isUndef() ||
23977       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
23978        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
23979     if (SDValue V = partitionShuffleOfConcats(N, DAG))
23980       return V;
23981   }
23982 
23983   // A shuffle of a concat of the same narrow vector can be reduced to use
23984   // only low-half elements of a concat with undef:
23985   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
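  // e.g. (illustrative) with a 2-element X: Mask <3,1,2,0> --> Mask' <1,1,0,0>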
23986   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
23987       N0.getNumOperands() == 2 &&
23988       N0.getOperand(0) == N0.getOperand(1)) {
23989     int HalfNumElts = (int)NumElts / 2;
23990     SmallVector<int, 8> NewMask;
23991     for (unsigned i = 0; i != NumElts; ++i) {
23992       int Idx = SVN->getMaskElt(i);
23993       if (Idx >= HalfNumElts) {
23994         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
23995         Idx -= HalfNumElts;
23996       }
23997       NewMask.push_back(Idx);
23998     }
23999     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
24000       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
24001       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
24002                                    N0.getOperand(0), UndefVec);
24003       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
24004     }
24005   }
24006 
24007   // See if we can replace a shuffle with an insert_subvector.
24008   // e.g. v2i32 into v8i32:
24009   // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
24010   // --> insert_subvector(lhs,rhs1,4).
24011   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
24012       TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
24013     auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
24014       // Ensure RHS subvectors are legal.
24015       assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
24016       EVT SubVT = RHS.getOperand(0).getValueType();
24017       int NumSubVecs = RHS.getNumOperands();
24018       int NumSubElts = SubVT.getVectorNumElements();
24019       assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
24020       if (!TLI.isTypeLegal(SubVT))
24021         return SDValue();
24022 
24023       // Don't bother if we have a unary shuffle (matches undef + LHS elts).
24024       if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
24025         return SDValue();
24026 
24027       // Search [NumSubElts] spans for RHS sequence.
24028       // TODO: Can we avoid nested loops to increase performance?
24029       SmallVector<int> InsertionMask(NumElts);
24030       for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
24031         for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
24032           // Reset mask to identity.
24033           std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
24034 
24035           // Add subvector insertion.
24036           std::iota(InsertionMask.begin() + SubIdx,
24037                     InsertionMask.begin() + SubIdx + NumSubElts,
24038                     NumElts + (SubVec * NumSubElts));
24039 
24040           // See if the shuffle mask matches the reference insertion mask.
24041           bool MatchingShuffle = true;
24042           for (int i = 0; i != (int)NumElts; ++i) {
24043             int ExpectIdx = InsertionMask[i];
24044             int ActualIdx = Mask[i];
24045             if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
24046               MatchingShuffle = false;
24047               break;
24048             }
24049           }
24050 
24051           if (MatchingShuffle)
24052             return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
24053                                RHS.getOperand(SubVec),
24054                                DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
24055         }
24056       }
24057       return SDValue();
24058     };
24059     ArrayRef<int> Mask = SVN->getMask();
24060     if (N1.getOpcode() == ISD::CONCAT_VECTORS)
24061       if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
24062         return InsertN1;
24063     if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
24064       SmallVector<int> CommuteMask(Mask);
24065       ShuffleVectorSDNode::commuteMask(CommuteMask);
24066       if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
24067         return InsertN0;
24068     }
24069   }
24070 
24071   // If we're not performing a select/blend shuffle, see if we can convert the
24072   // shuffle into an AND node, with all the out-of-lane elements known zero.
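  // e.g. (illustrative, v4i32) shuffle X, Y, <0,6,2,7>, where lanes 2 and 3
  // of Y are known zero --> and (bitcast X), <-1,0,-1,0>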
24073   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
24074     bool IsInLaneMask = true;
24075     ArrayRef<int> Mask = SVN->getMask();
24076     SmallVector<int, 16> ClearMask(NumElts, -1);
24077     APInt DemandedLHS = APInt::getNullValue(NumElts);
24078     APInt DemandedRHS = APInt::getNullValue(NumElts);
24079     for (int I = 0; I != (int)NumElts; ++I) {
24080       int M = Mask[I];
24081       if (M < 0)
24082         continue;
24083       ClearMask[I] = M == I ? I : (I + NumElts);
24084       IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
24085       if (M != I) {
24086         APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
24087         Demanded.setBit(M % NumElts);
24088       }
24089     }
24090     // TODO: Should we try to mask with N1 as well?
24091     if (!IsInLaneMask &&
24092         (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) &&
24093         (DemandedLHS.isNullValue() ||
24094          DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
24095         (DemandedRHS.isNullValue() ||
24096          DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
24097       SDLoc DL(N);
24098       EVT IntVT = VT.changeVectorElementTypeToInteger();
24099       EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
24100       // Transform the type to a legal type so that the buildvector constant
24101       // elements are not illegal. Make sure that the result is larger than the
24102   // original type, in case the value is split into two (e.g. i64->i32).
24103       if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
24104         IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
24105       if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
24106         SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
24107         SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
24108         SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
24109         for (int I = 0; I != (int)NumElts; ++I)
24110           if (0 <= Mask[I])
24111             AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
24112 
24113         // See if a clear mask is legal instead of going via
24114         // XformToShuffleWithZero which loses UNDEF mask elements.
24115         if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
24116           return DAG.getBitcast(
24117               VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
24118                                       DAG.getConstant(0, DL, IntVT), ClearMask));
24119 
24120         if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
24121           return DAG.getBitcast(
24122               VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
24123                               DAG.getBuildVector(IntVT, DL, AndMask)));
24124       }
24125     }
24126   }
24127 
24128   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
24129   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
24130   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
24131     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
24132       return Res;
24133 
24134   // If this shuffle only has a single input that is a bitcasted shuffle,
24135   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
24136   // back to their original types.
24137   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
24138       N1.isUndef() && Level < AfterLegalizeVectorOps &&
24139       TLI.isTypeLegal(VT)) {
24140 
24141     SDValue BC0 = peekThroughOneUseBitcasts(N0);
24142     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
24143       EVT SVT = VT.getScalarType();
24144       EVT InnerVT = BC0->getValueType(0);
24145       EVT InnerSVT = InnerVT.getScalarType();
24146 
24147       // Determine which shuffle works with the smaller scalar type.
24148       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
24149       EVT ScaleSVT = ScaleVT.getScalarType();
24150 
24151       if (TLI.isTypeLegal(ScaleVT) &&
24152           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
24153           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
24154         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
24155         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
24156 
24157         // Scale the shuffle masks to the smaller scalar type.
24158         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
24159         SmallVector<int, 8> InnerMask;
24160         SmallVector<int, 8> OuterMask;
24161         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
24162         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
24163 
24164         // Merge the shuffle masks.
24165         SmallVector<int, 8> NewMask;
24166         for (int M : OuterMask)
24167           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
24168 
24169         // Test for shuffle mask legality over both commutations.
24170         SDValue SV0 = BC0->getOperand(0);
24171         SDValue SV1 = BC0->getOperand(1);
24172         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
24173         if (!LegalMask) {
24174           std::swap(SV0, SV1);
24175           ShuffleVectorSDNode::commuteMask(NewMask);
24176           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
24177         }
24178 
24179         if (LegalMask) {
24180           SV0 = DAG.getBitcast(ScaleVT, SV0);
24181           SV1 = DAG.getBitcast(ScaleVT, SV1);
24182           return DAG.getBitcast(
24183               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
24184         }
24185       }
24186     }
24187   }
24188 
24189   // Match shuffles of bitcasts, so long as the mask can be treated as the
24190   // larger type.
24191   if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
24192     return V;
24193 
24194   // Compute the combined shuffle mask for a shuffle with SV0 as the first
24195   // operand, and SV1 as the second operand.
24196   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
24197   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
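  // e.g. (illustrative, v4 masks):
  //   shuffle (shuffle A, B, <0,4,1,5>), B, <0,1,4,5>
  //   --> SV0 = A, SV1 = B, Mask = <0,4,4,5>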
24198   auto MergeInnerShuffle =
24199       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
24200                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
24201                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
24202                      SmallVectorImpl<int> &Mask) -> bool {
24203     // Don't try to fold splats; they're likely to simplify somehow, or they
24204     // might be free.
24205     if (OtherSVN->isSplat())
24206       return false;
24207 
24208     SV0 = SV1 = SDValue();
24209     Mask.clear();
24210 
24211     for (unsigned i = 0; i != NumElts; ++i) {
24212       int Idx = SVN->getMaskElt(i);
24213       if (Idx < 0) {
24214         // Propagate Undef.
24215         Mask.push_back(Idx);
24216         continue;
24217       }
24218 
24219       if (Commute)
24220         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
24221 
24222       SDValue CurrentVec;
24223       if (Idx < (int)NumElts) {
24224         // This shuffle index refers to the inner shuffle N0. Lookup the inner
24225         // shuffle mask to identify which vector is actually referenced.
24226         Idx = OtherSVN->getMaskElt(Idx);
24227         if (Idx < 0) {
24228           // Propagate Undef.
24229           Mask.push_back(Idx);
24230           continue;
24231         }
24232         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
24233                                           : OtherSVN->getOperand(1);
24234       } else {
24235         // This shuffle index references an element within N1.
24236         CurrentVec = N1;
24237       }
24238 
24239       // Simple case where 'CurrentVec' is UNDEF.
24240       if (CurrentVec.isUndef()) {
24241         Mask.push_back(-1);
24242         continue;
24243       }
24244 
24245       // Canonicalize the shuffle index. We don't know yet if CurrentVec
24246       // will be the first or second operand of the combined shuffle.
24247       Idx = Idx % NumElts;
24248       if (!SV0.getNode() || SV0 == CurrentVec) {
24249         // Ok. CurrentVec is the left hand side.
24250         // Update the mask accordingly.
24251         SV0 = CurrentVec;
24252         Mask.push_back(Idx);
24253         continue;
24254       }
24255       if (!SV1.getNode() || SV1 == CurrentVec) {
24256         // Ok. CurrentVec is the right hand side.
24257         // Update the mask accordingly.
24258         SV1 = CurrentVec;
24259         Mask.push_back(Idx + NumElts);
24260         continue;
24261       }
24262 
24263       // Last chance - see if the vector is another shuffle and if it
24264       // uses one of the existing candidate shuffle ops.
24265       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
24266         int InnerIdx = CurrentSVN->getMaskElt(Idx);
24267         if (InnerIdx < 0) {
24268           Mask.push_back(-1);
24269           continue;
24270         }
24271         SDValue InnerVec = (InnerIdx < (int)NumElts)
24272                                ? CurrentSVN->getOperand(0)
24273                                : CurrentSVN->getOperand(1);
24274         if (InnerVec.isUndef()) {
24275           Mask.push_back(-1);
24276           continue;
24277         }
24278         InnerIdx %= NumElts;
24279         if (InnerVec == SV0) {
24280           Mask.push_back(InnerIdx);
24281           continue;
24282         }
24283         if (InnerVec == SV1) {
24284           Mask.push_back(InnerIdx + NumElts);
24285           continue;
24286         }
24287       }
24288 
24289       // Bail out if we cannot convert the shuffle pair into a single shuffle.
24290       return false;
24291     }
24292 
24293     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
24294       return true;
24295 
24296     // Avoid introducing shuffles with illegal mask.
24297     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
24298     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
24299     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
24300     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
24301     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
24302     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
24303     if (TLI.isShuffleMaskLegal(Mask, VT))
24304       return true;
24305 
24306     std::swap(SV0, SV1);
24307     ShuffleVectorSDNode::commuteMask(Mask);
24308     return TLI.isShuffleMaskLegal(Mask, VT);
24309   };
24310 
24311   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
24312     // Canonicalize shuffles according to rules:
24313     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
24314     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
24315     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
24316     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
24317         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
24318       // The incoming shuffle must be of the same type as the result of the
24319       // current shuffle.
24320       assert(N1->getOperand(0).getValueType() == VT &&
24321              "Shuffle types don't match");
24322 
24323       SDValue SV0 = N1->getOperand(0);
24324       SDValue SV1 = N1->getOperand(1);
24325       bool HasSameOp0 = N0 == SV0;
24326       bool IsSV1Undef = SV1.isUndef();
24327       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
24328         // Commute the operands of this shuffle so merging below will trigger.
24329         return DAG.getCommutedVectorShuffle(*SVN);
24330     }
24331 
24332     // Canonicalize splat shuffles to the RHS to improve merging below.
24333     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
24334     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
24335         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
24336         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
24337         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
24338       return DAG.getCommutedVectorShuffle(*SVN);
24339     }
24340 
24341     // Try to fold according to rules:
24342     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
24343     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
24344     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
24345     // Don't try to fold shuffles with illegal type.
24346     // Only fold if this shuffle is the only user of the other shuffle.
24347     // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
24348     for (int i = 0; i != 2; ++i) {
24349       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
24350           N->isOnlyUserOf(N->getOperand(i).getNode())) {
24351         // The incoming shuffle must be of the same type as the result of the
24352         // current shuffle.
24353         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
24354         assert(OtherSV->getOperand(0).getValueType() == VT &&
24355                "Shuffle types don't match");
24356 
24357         SDValue SV0, SV1;
24358         SmallVector<int, 4> Mask;
24359         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
24360                               SV0, SV1, Mask)) {
24361           // Check if all indices in Mask are Undef. If so, propagate Undef.
24362           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
24363             return DAG.getUNDEF(VT);
24364 
24365           return DAG.getVectorShuffle(VT, SDLoc(N),
24366                                       SV0 ? SV0 : DAG.getUNDEF(VT),
24367                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
24368         }
24369       }
24370     }
24371 
24372     // Merge shuffles through binops if we are able to merge it with at least
24373     // one other shuffle.
24374     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
24375     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
24376     unsigned SrcOpcode = N0.getOpcode();
24377     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
24378         (N1.isUndef() ||
24379          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
24380       // Get binop source ops, or just pass on the undef.
24381       SDValue Op00 = N0.getOperand(0);
24382       SDValue Op01 = N0.getOperand(1);
24383       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
24384       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
24385       // TODO: We might be able to relax the VT check but we don't currently
24386       // have any isBinOp() that has different result/ops VTs so play safe until
24387       // we have test coverage.
24388       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
24389           Op01.getValueType() == VT && Op11.getValueType() == VT &&
24390           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
24391            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
24392            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
24393            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
24394         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
24395                                         SmallVectorImpl<int> &Mask, bool LeftOp,
24396                                         bool Commute) {
24397           SDValue InnerN = Commute ? N1 : N0;
24398           SDValue Op0 = LeftOp ? Op00 : Op01;
24399           SDValue Op1 = LeftOp ? Op10 : Op11;
24400           if (Commute)
24401             std::swap(Op0, Op1);
24402           // Only accept the merged shuffle if we don't introduce undef elements,
24403           // or the inner shuffle already contained undef elements.
24404           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
24405           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
24406                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
24407                                    Mask) &&
24408                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
24409                   llvm::none_of(Mask, [](int M) { return M < 0; }));
24410         };
24411 
24412         // Ensure we don't increase the number of shuffles - we must merge a
24413         // shuffle from at least one of the LHS and RHS ops.
24414         bool MergedLeft = false;
24415         SDValue LeftSV0, LeftSV1;
24416         SmallVector<int, 4> LeftMask;
24417         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
24418             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
24419           MergedLeft = true;
24420         } else {
24421           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
24422           LeftSV0 = Op00, LeftSV1 = Op10;
24423         }
24424 
24425         bool MergedRight = false;
24426         SDValue RightSV0, RightSV1;
24427         SmallVector<int, 4> RightMask;
24428         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
24429             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
24430           MergedRight = true;
24431         } else {
24432           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
24433           RightSV0 = Op01, RightSV1 = Op11;
24434         }
24435 
24436         if (MergedLeft || MergedRight) {
24437           SDLoc DL(N);
24438           SDValue LHS = DAG.getVectorShuffle(
24439               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
24440               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
24441           SDValue RHS = DAG.getVectorShuffle(
24442               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
24443               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
24444           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
24445         }
24446       }
24447     }
24448   }
24449 
24450   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
24451     return V;
24452 
24453   // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
24454   // Perform this really late, because it could eliminate knowledge
24455   // of undef elements created by this shuffle.
24456   if (Level < AfterLegalizeTypes)
24457     if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
24458                                                           LegalOperations))
24459       return V;
24460 
24461   return SDValue();
24462 }
24463 
24464 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
24465   EVT VT = N->getValueType(0);
24466   if (!VT.isFixedLengthVector())
24467     return SDValue();
24468 
24469   // Try to convert a scalar binop with an extracted vector element to a vector
24470   // binop. This is intended to reduce potentially expensive register moves.
24471   // TODO: Check if both operands are extracted.
24472   // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
24473   SDValue Scalar = N->getOperand(0);
24474   unsigned Opcode = Scalar.getOpcode();
24475   EVT VecEltVT = VT.getScalarType();
24476   if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
24477       TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
24478       Scalar.getOperand(0).getValueType() == VecEltVT &&
24479       Scalar.getOperand(1).getValueType() == VecEltVT &&
24480       DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
24481     // Match an extract element and get a shuffle mask equivalent.
24482     SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
24483 
24484     for (int i : {0, 1}) {
24485       // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
24486       // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
24487       SDValue EE = Scalar.getOperand(i);
24488       auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
24489       if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24490           EE.getOperand(0).getValueType() == VT &&
24491           isa<ConstantSDNode>(EE.getOperand(1))) {
24492         // Mask = {ExtractIndex, undef, undef....}
24493         ShufMask[0] = EE.getConstantOperandVal(1);
24494         // Make sure the shuffle is legal if we are crossing lanes.
24495         if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
24496           SDLoc DL(N);
24497           SDValue V[] = {EE.getOperand(0),
24498                          DAG.getConstant(C->getAPIntValue(), DL, VT)};
24499           SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
24500           return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
24501                                       ShufMask);
24502         }
24503       }
24504     }
24505   }
24506 
24507   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
24508   // with a VECTOR_SHUFFLE and possible truncate.
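  // e.g. (illustrative) s2v (extelt v4i32:V, 2) --> shuffle V, undef,
  //      <2,-1,-1,-1>, when the result type is also v4i32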
24509   if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
24510       !Scalar.getOperand(0).getValueType().isFixedLengthVector())
24511     return SDValue();
24512 
24513   // If we have an implicit truncate, truncate here if it is legal.
24514   if (VecEltVT != Scalar.getValueType() &&
24515       Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
24516     SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
24517     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
24518   }
24519 
24520   auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
24521   if (!ExtIndexC)
24522     return SDValue();
24523 
24524   SDValue SrcVec = Scalar.getOperand(0);
24525   EVT SrcVT = SrcVec.getValueType();
24526   unsigned SrcNumElts = SrcVT.getVectorNumElements();
24527   unsigned VTNumElts = VT.getVectorNumElements();
24528   if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
24529     // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
24530     SmallVector<int, 8> Mask(SrcNumElts, -1);
24531     Mask[0] = ExtIndexC->getZExtValue();
24532     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
24533         SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
24534     if (!LegalShuffle)
24535       return SDValue();
24536 
24537     // If the initial vector is the same size, the shuffle is the result.
24538     if (VT == SrcVT)
24539       return LegalShuffle;
24540 
24541     // If not, shorten the shuffled vector.
24542     if (VTNumElts != SrcNumElts) {
24543       SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
24544       EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
24545                                    SrcVT.getVectorElementType(), VTNumElts);
24546       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
24547                          ZeroIdx);
24548     }
24549   }
24550 
24551   return SDValue();
24552 }
24553 
24554 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
24555   EVT VT = N->getValueType(0);
24556   SDValue N0 = N->getOperand(0);
24557   SDValue N1 = N->getOperand(1);
24558   SDValue N2 = N->getOperand(2);
24559   uint64_t InsIdx = N->getConstantOperandVal(2);
24560 
24561   // If inserting an UNDEF, just return the original vector.
24562   if (N1.isUndef())
24563     return N0;
24564 
24565   // If this is an insert of an extracted vector into an undef vector, we can
24566   // just use the input to the extract.
24567   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24568       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
24569     return N1.getOperand(0);
24570 
24571   // Simplify scalar inserts into an undef vector:
24572   // insert_subvector undef, (splat X), N2 -> splat X
24573   if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
24574     return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
24575 
24576   // If we are inserting a bitcast value into an undef, with the same
24577   // number of elements, just use the bitcast input of the extract.
24578   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
24579   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
24580   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
24581       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24582       N1.getOperand(0).getOperand(1) == N2 &&
24583       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
24584           VT.getVectorElementCount() &&
24585       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
24586           VT.getSizeInBits()) {
24587     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
24588   }
24589 
24590   // If both N0 and N1 are bitcast values on which insert_subvector
24591   // would make sense, pull the bitcast through.
24592   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
24593   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
24594   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
24595     SDValue CN0 = N0.getOperand(0);
24596     SDValue CN1 = N1.getOperand(0);
24597     EVT CN0VT = CN0.getValueType();
24598     EVT CN1VT = CN1.getValueType();
24599     if (CN0VT.isVector() && CN1VT.isVector() &&
24600         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
24601         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
24602       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
24603                                       CN0.getValueType(), CN0, CN1, N2);
24604       return DAG.getBitcast(VT, NewINSERT);
24605     }
24606   }
24607 
24608   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
24609   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
24610   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
24611   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
24612       N0.getOperand(1).getValueType() == N1.getValueType() &&
24613       N0.getOperand(2) == N2)
24614     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
24615                        N1, N2);
24616 
24617   // Eliminate an intermediate insert into an undef vector:
24618   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
24619   // insert_subvector undef, X, N2
24620   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
24621       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
24622     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
24623                        N1.getOperand(1), N2);
24624 
24625   // Push subvector bitcasts to the output, adjusting the index as we go.
24626   // insert_subvector(bitcast(v), bitcast(s), c1)
24627   // -> bitcast(insert_subvector(v, s, c2))
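  // e.g. (illustrative, v8i32 result):
  //   insert_subvector (bitcast v4i64:V), (bitcast v1i64:S), 2
  //   --> bitcast (insert_subvector V, S, 1)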
24628   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
24629       N1.getOpcode() == ISD::BITCAST) {
24630     SDValue N0Src = peekThroughBitcasts(N0);
24631     SDValue N1Src = peekThroughBitcasts(N1);
24632     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
24633     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
24634     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
24635         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
24636       EVT NewVT;
24637       SDLoc DL(N);
24638       SDValue NewIdx;
24639       LLVMContext &Ctx = *DAG.getContext();
24640       ElementCount NumElts = VT.getVectorElementCount();
24641       unsigned EltSizeInBits = VT.getScalarSizeInBits();
24642       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
24643         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
24644         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
24645         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
24646       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
24647         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
24648         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
24649           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
24650                                    NumElts.divideCoefficientBy(Scale));
24651           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
24652         }
24653       }
24654       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
24655         SDValue Res = DAG.getBitcast(NewVT, N0Src);
24656         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
24657         return DAG.getBitcast(VT, Res);
24658       }
24659     }
24660   }
24661 
24662   // Canonicalize insert_subvector dag nodes.
24663   // Example:
24664   // (insert_subvector (insert_subvector A, B, Idx0), C, Idx1), Idx1 < Idx0
24665   // -> (insert_subvector (insert_subvector A, C, Idx1), B, Idx0)
24666   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
24667       N1.getValueType() == N0.getOperand(1).getValueType()) {
24668     unsigned OtherIdx = N0.getConstantOperandVal(2);
24669     if (InsIdx < OtherIdx) {
24670       // Swap nodes.
24671       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
24672                                   N0.getOperand(0), N1, N2);
24673       AddToWorklist(NewOp.getNode());
24674       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
24675                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
24676     }
24677   }
24678 
24679   // If the input vector is a concatenation, and the insert replaces
24680   // one of the pieces, we can optimize into a single concat_vectors.
24681   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
24682       N0.getOperand(0).getValueType() == N1.getValueType() &&
24683       N0.getOperand(0).getValueType().isScalableVector() ==
24684           N1.getValueType().isScalableVector()) {
24685     unsigned Factor = N1.getValueType().getVectorMinNumElements();
24686     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
24687     Ops[InsIdx / Factor] = N1;
24688     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
24689   }
24690 
24691   // Simplify source operands based on insertion.
24692   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
24693     return SDValue(N, 0);
24694 
24695   return SDValue();
24696 }
24697 
24698 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
24699   SDValue N0 = N->getOperand(0);
24700 
24701   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
24702   if (N0->getOpcode() == ISD::FP16_TO_FP)
24703     return N0->getOperand(0);
24704 
24705   return SDValue();
24706 }
24707 
24708 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
24709   SDValue N0 = N->getOperand(0);
24710 
24711   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
24712   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
24713     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
24714     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
24715       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
24716                          N0.getOperand(0));
24717     }
24718   }
24719 
24720   return SDValue();
24721 }
24722 
24723 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
24724   SDValue N0 = N->getOperand(0);
24725 
24726   // fold (fp_to_bf16 (bf16_to_fp op)) -> op
24727   if (N0->getOpcode() == ISD::BF16_TO_FP)
24728     return N0->getOperand(0);
24729 
24730   return SDValue();
24731 }
24732 
24733 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
24734   SDValue N0 = N->getOperand(0);
24735   EVT VT = N0.getValueType();
24736   unsigned Opcode = N->getOpcode();
24737 
24738   // VECREDUCE over 1-element vector is just an extract.
24739   if (VT.getVectorElementCount().isScalar()) {
24740     SDLoc dl(N);
24741     SDValue Res =
24742         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
24743                     DAG.getVectorIdxConstant(0, dl));
24744     if (Res.getValueType() != N->getValueType(0))
24745       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
24746     return Res;
24747   }
24748 
24749   // On a boolean vector, an and/or reduction is the same as a umin/umax
24750   // reduction. Convert them if the latter is legal while the former isn't.
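  // e.g. with elements known to be 0 or all-ones, vecreduce_and yields 0 iff
  // any element is 0, exactly like vecreduce_umin.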
24751   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
24752     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
24753         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
24754     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
24755         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
24756         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
24757       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
24758   }
24759 
24760   // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
24761   // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
24762   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
24763       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
24764     SDValue Vec = N0.getOperand(0);
24765     SDValue Subvec = N0.getOperand(1);
24766     if ((Opcode == ISD::VECREDUCE_OR &&
24767          (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
24768         (Opcode == ISD::VECREDUCE_AND &&
24769          (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
24770       return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
24771   }
24772 
24773   return SDValue();
24774 }
24775 
24776 SDValue DAGCombiner::visitVPOp(SDNode *N) {
24777 
24778   if (N->getOpcode() == ISD::VP_GATHER)
24779     if (SDValue SD = visitVPGATHER(N))
24780       return SD;
24781 
24782   if (N->getOpcode() == ISD::VP_SCATTER)
24783     if (SDValue SD = visitVPSCATTER(N))
24784       return SD;
24785 
24786   // VP operations in which all vector elements are disabled - either by
24787   // determining that the mask is all false or that the EVL is 0 - can be
24788   // eliminated.
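  // e.g. (illustrative) vp.add X, Y, <all-false mask>, EVL, or any VP op
  // whose EVL is 0, computes no active lanes.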
24789   bool AreAllEltsDisabled = false;
24790   if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
24791     AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
24792   if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
24793     AreAllEltsDisabled |=
24794         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
24795 
24796   // This is the only generic VP combine we support for now.
24797   if (!AreAllEltsDisabled)
24798     return SDValue();
24799 
24800   // Binary operations can be replaced by UNDEF.
24801   if (ISD::isVPBinaryOp(N->getOpcode()))
24802     return DAG.getUNDEF(N->getValueType(0));
24803 
24804   // VP Memory operations can be replaced by either the chain (stores) or the
24805   // chain + undef (loads).
24806   if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
24807     if (MemSD->writeMem())
24808       return MemSD->getChain();
24809     return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
24810   }
24811 
24812   // Reduction operations return the start operand when no elements are active.
24813   if (ISD::isVPReduction(N->getOpcode()))
24814     return N->getOperand(0);
24815 
24816   return SDValue();
24817 }
24818 
24819 /// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
24820 /// with the destination vector and a zero vector.
24821 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
24822 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
24823 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
24824   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
24825 
24826   EVT VT = N->getValueType(0);
24827   SDValue LHS = N->getOperand(0);
24828   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
24829   SDLoc DL(N);
24830 
24831   // Make sure we're not running after operation legalization where it
24832   // may have custom lowered the vector shuffles.
24833   if (LegalOperations)
24834     return SDValue();
24835 
24836   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
24837     return SDValue();
24838 
24839   EVT RVT = RHS.getValueType();
24840   unsigned NumElts = RHS.getNumOperands();
24841 
24842   // Attempt to create a valid clear mask, splitting the mask into sub
24843   // elements and checking to see if each is all zeros or all ones -
24844   // suitable for shuffle masking.
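  // e.g. (illustrative, little-endian):
  //   and v2i64 V, <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>
  //   --> with Split == 2: vector_shuffle (v4i32 bitcast V), zero, <0,5,6,3>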
24845   auto BuildClearMask = [&](int Split) {
24846     int NumSubElts = NumElts * Split;
24847     int NumSubBits = RVT.getScalarSizeInBits() / Split;
24848 
24849     SmallVector<int, 8> Indices;
24850     for (int i = 0; i != NumSubElts; ++i) {
24851       int EltIdx = i / Split;
24852       int SubIdx = i % Split;
24853       SDValue Elt = RHS.getOperand(EltIdx);
24854       // X & undef --> 0 (not undef). So this lane must be converted to choose
24855       // from the zero constant vector (same as if the element had all 0-bits).
24856       if (Elt.isUndef()) {
24857         Indices.push_back(i + NumSubElts);
24858         continue;
24859       }
24860 
24861       APInt Bits;
24862       if (isa<ConstantSDNode>(Elt))
24863         Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
24864       else if (isa<ConstantFPSDNode>(Elt))
24865         Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
24866       else
24867         return SDValue();
24868 
24869       // Extract the sub element from the constant bit mask.
24870       if (DAG.getDataLayout().isBigEndian())
24871         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
24872       else
24873         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
24874 
24875       if (Bits.isAllOnes())
24876         Indices.push_back(i);
24877       else if (Bits == 0)
24878         Indices.push_back(i + NumSubElts);
24879       else
24880         return SDValue();
24881     }
24882 
24883     // Let's see if the target supports this vector_shuffle.
24884     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
24885     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
24886     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
24887       return SDValue();
24888 
24889     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
24890     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
24891                                                    DAG.getBitcast(ClearVT, LHS),
24892                                                    Zero, Indices));
24893   };
24894 
24895   // Determine maximum split level (byte level masking).
24896   int MaxSplit = 1;
24897   if (RVT.getScalarSizeInBits() % 8 == 0)
24898     MaxSplit = RVT.getScalarSizeInBits() / 8;
24899 
24900   for (int Split = 1; Split <= MaxSplit; ++Split)
24901     if (RVT.getScalarSizeInBits() % Split == 0)
24902       if (SDValue S = BuildClearMask(Split))
24903         return S;
24904 
24905   return SDValue();
24906 }
24907 
24908 /// If a vector binop is performed on splat values, it may be profitable to
24909 /// extract, scalarize, and insert/splat.
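/// e.g. (illustrative) udiv (splat X), (splat Y) --> splat (udiv X, Y),
/// trading per-lane divides for a single scalar divide plus a splat.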
24910 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
24911                                       const SDLoc &DL) {
24912   SDValue N0 = N->getOperand(0);
24913   SDValue N1 = N->getOperand(1);
24914   unsigned Opcode = N->getOpcode();
24915   EVT VT = N->getValueType(0);
24916   EVT EltVT = VT.getVectorElementType();
24917   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24918 
24919   // TODO: Remove/replace the extract cost check? If the elements are available
24920   //       as scalars, then there may be no extract cost. Should we ask if
24921   //       inserting a scalar back into a vector is cheap instead?
24922   int Index0, Index1;
24923   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
24924   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
24925   // Extracting an element from a splat_vector should be free.
24926   // TODO: use DAG.isSplatValue instead?
24927   bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
24928                            N1.getOpcode() == ISD::SPLAT_VECTOR;
24929   if (!Src0 || !Src1 || Index0 != Index1 ||
24930       Src0.getValueType().getVectorElementType() != EltVT ||
24931       Src1.getValueType().getVectorElementType() != EltVT ||
24932       !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
24933       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
24934     return SDValue();
24935 
24936   SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
24937   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
24938   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
24939   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
24940 
24941   // If all lanes but 1 are undefined, no need to splat the scalar result.
24942   // TODO: Keep track of undefs and use that info in the general case.
24943   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
24944       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
24945       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
24946     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
24947     // build_vec ..undef, (bo X, Y), undef...
24948     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
24949     Ops[Index0] = ScalarBO;
24950     return DAG.getBuildVector(VT, DL, Ops);
24951   }
24952 
24953   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
24954   return DAG.getSplat(VT, DL, ScalarBO);
24955 }
24956 
24957 /// Visit a vector cast operation, like FP_EXTEND.
24958 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
24959   EVT VT = N->getValueType(0);
24960   assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
24961   EVT EltVT = VT.getVectorElementType();
24962   unsigned Opcode = N->getOpcode();
24963 
24964   SDValue N0 = N->getOperand(0);
24965   EVT SrcVT = N0->getValueType(0);
24966   EVT SrcEltVT = SrcVT.getVectorElementType();
24967   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24968 
24969   // TODO: promoting the operation might also be good here?
24970   int Index0;
24971   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
24972   if (Src0 &&
24973       (N0.getOpcode() == ISD::SPLAT_VECTOR ||
24974        TLI.isExtractVecEltCheap(VT, Index0)) &&
24975       TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
24976       TLI.preferScalarizeSplat(Opcode)) {
24977     SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
24978     SDValue Elt =
24979         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
24980     SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
24981     if (VT.isScalableVector())
24982       return DAG.getSplatVector(VT, DL, ScalarBO);
24983     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
24984     return DAG.getBuildVector(VT, DL, Ops);
24985   }
24986 
24987   return SDValue();
24988 }
24989 
24990 /// Visit a binary vector operation, like ADD.
24991 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
24992   EVT VT = N->getValueType(0);
24993   assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
24994 
24995   SDValue LHS = N->getOperand(0);
24996   SDValue RHS = N->getOperand(1);
24997   unsigned Opcode = N->getOpcode();
24998   SDNodeFlags Flags = N->getFlags();
24999 
25000   // Move unary shuffles with identical masks after a vector binop:
25001   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
25002   //   --> shuffle (VBinOp A, B), Undef, Mask
25003   // This does not require type legality checks because we are creating the
25004   // same types of operations that are in the original sequence. We do have to
25005   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
25006   // though. This code is adapted from the identical transform in instcombine.
25007   if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
25008     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
25009     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
25010     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
25011         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
25012         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
25013       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
25014                                      RHS.getOperand(0), Flags);
25015       SDValue UndefV = LHS.getOperand(1);
25016       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
25017     }
25018 
25019     // Try to sink a splat shuffle after a binop with a uniform constant.
25020     // This is limited to cases where neither the shuffle nor the constant have
25021     // undefined elements because that could be poison-unsafe or inhibit
25022     // demanded elements analysis. It is further limited to not change a splat
25023     // of an inserted scalar because that may be optimized better by
25024     // load-folding or other target-specific behaviors.
25025     if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
25026         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
25027         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
25028       // binop (splat X), (splat C) --> splat (binop X, C)
25029       SDValue X = Shuf0->getOperand(0);
25030       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
25031       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
25032                                   Shuf0->getMask());
25033     }
25034     if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
25035         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
25036         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
25037       // binop (splat C), (splat X) --> splat (binop C, X)
25038       SDValue X = Shuf1->getOperand(0);
25039       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
25040       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
25041                                   Shuf1->getMask());
25042     }
25043   }
25044 
25045   // The following pattern is likely to emerge with vector reduction ops. Moving
25046   // the binary operation ahead of insertion may allow using a narrower vector
25047   // instruction that has better performance than the wide version of the op:
25048   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
25049   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
25050       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
25051       LHS.getOperand(2) == RHS.getOperand(2) &&
25052       (LHS.hasOneUse() || RHS.hasOneUse())) {
25053     SDValue X = LHS.getOperand(1);
25054     SDValue Y = RHS.getOperand(1);
25055     SDValue Z = LHS.getOperand(2);
25056     EVT NarrowVT = X.getValueType();
25057     if (NarrowVT == Y.getValueType() &&
25058         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
25059                                               LegalOperations)) {
25060       // (binop undef, undef) may not return undef, so compute that result.
25061       SDValue VecC =
25062           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
25063       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
25064       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
25065     }
25066   }
25067 
25068   // Make sure all but the first op are undef or constant.
25069   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
25070     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
25071            all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
25072              return Op.isUndef() ||
25073                     ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
25074            });
25075   };
25076 
25077   // The following pattern is likely to emerge with vector reduction ops. Moving
25078   // the binary operation ahead of the concat may allow using a narrower vector
25079   // instruction that has better performance than the wide version of the op:
25080   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
25081   //   concat (VBinOp X, Y), VecC
25082   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
25083       (LHS.hasOneUse() || RHS.hasOneUse())) {
25084     EVT NarrowVT = LHS.getOperand(0).getValueType();
25085     if (NarrowVT == RHS.getOperand(0).getValueType() &&
25086         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
25087       unsigned NumOperands = LHS.getNumOperands();
25088       SmallVector<SDValue, 4> ConcatOps;
25089       for (unsigned i = 0; i != NumOperands; ++i) {
25090         // For operands 1 and up, this constant-folds (they are undef or constant).
25091         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
25092                                         RHS.getOperand(i)));
25093       }
25094 
25095       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
25096     }
25097   }
25098 
25099   if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
25100     return V;
25101 
25102   return SDValue();
25103 }
25104 
25105 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
25106                                     SDValue N2) {
25107   assert(N0.getOpcode() == ISD::SETCC &&
25108          "First argument must be a SetCC node!");
25109 
25110   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
25111                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
25112 
25113   // If we got a simplified select_cc node back from SimplifySelectCC, then
25114   // break it down into a new SETCC node, and a new SELECT node, and then return
25115   // the SELECT node, since we were called with a SELECT node.
25116   if (SCC.getNode()) {
25117     // Check to see if we got a select_cc back (to turn into setcc/select).
25118     // Otherwise, just return whatever node we got back, like fabs.
25119     if (SCC.getOpcode() == ISD::SELECT_CC) {
25120       const SDNodeFlags Flags = N0->getFlags();
25121       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
25122                                   N0.getValueType(),
25123                                   SCC.getOperand(0), SCC.getOperand(1),
25124                                   SCC.getOperand(4), Flags);
25125       AddToWorklist(SETCC.getNode());
25126       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
25127                                          SCC.getOperand(2), SCC.getOperand(3));
25128       SelectNode->setFlags(Flags);
25129       return SelectNode;
25130     }
25131 
25132     return SCC;
25133   }
25134   return SDValue();
25135 }
25136 
25137 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
25138 /// being selected between, see if we can simplify the select.  Callers of this
25139 /// should assume that TheSelect is deleted if this returns true.  As such, they
25140 /// should return the appropriate thing (e.g. the node) back to the top-level of
25141 /// the DAG combiner loop to avoid it being looked at.
25142 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
25143                                     SDValue RHS) {
25144   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
25145   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
25146   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
25147     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
25148       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
25149       SDValue Sqrt = RHS;
25150       ISD::CondCode CC;
25151       SDValue CmpLHS;
25152       const ConstantFPSDNode *Zero = nullptr;
25153 
25154       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
25155         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
25156         CmpLHS = TheSelect->getOperand(0);
25157         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
25158       } else {
25159         // SELECT or VSELECT
25160         SDValue Cmp = TheSelect->getOperand(0);
25161         if (Cmp.getOpcode() == ISD::SETCC) {
25162           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
25163           CmpLHS = Cmp.getOperand(0);
25164           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
25165         }
25166       }
25167       if (Zero && Zero->isZero() &&
25168           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
25169           CC == ISD::SETULT || CC == ISD::SETLT)) {
25170         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
25171         CombineTo(TheSelect, Sqrt);
25172         return true;
25173       }
25174     }
25175   }
25176   // Cannot simplify select with vector condition
25177   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
25178 
25179   // If this is a select from two identical things, try to pull the operation
25180   // through the select.
25181   if (LHS.getOpcode() != RHS.getOpcode() ||
25182       !LHS.hasOneUse() || !RHS.hasOneUse())
25183     return false;
25184 
25185   // If this is a load and the token chain is identical, replace the select
25186   // of two loads with a load through a select of the address to load from.
25187   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
25188   // constants have been dropped into the constant pool.
25189   if (LHS.getOpcode() == ISD::LOAD) {
25190     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
25191     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
25192 
25193     // Token chains must be identical.
25194     if (LHS.getOperand(0) != RHS.getOperand(0) ||
25195         // Do not let this transformation reduce the number of volatile loads.
25196         // Be conservative for atomics for the moment
25197         // TODO: This does appear to be legal for unordered atomics (see D66309)
25198         !LLD->isSimple() || !RLD->isSimple() ||
25199         // FIXME: If either is a pre/post inc/dec load,
25200         // we'd need to split out the address adjustment.
25201         LLD->isIndexed() || RLD->isIndexed() ||
25202         // If this is an EXTLOAD, the VT's must match.
25203         LLD->getMemoryVT() != RLD->getMemoryVT() ||
25204         // If this is an EXTLOAD, the kind of extension must match.
25205         (LLD->getExtensionType() != RLD->getExtensionType() &&
25206          // The only exception is if one of the extensions is anyext.
25207          LLD->getExtensionType() != ISD::EXTLOAD &&
25208          RLD->getExtensionType() != ISD::EXTLOAD) ||
25209         // FIXME: this discards src value information.  This is
25210         // over-conservative. It would be beneficial to be able to remember
25211         // both potential memory locations.  Since we are discarding
25212         // src value info, don't do the transformation if the memory
25213         // locations are not in the default address space.
25214         LLD->getPointerInfo().getAddrSpace() != 0 ||
25215         RLD->getPointerInfo().getAddrSpace() != 0 ||
25216         // We can't produce a CMOV of a TargetFrameIndex since we won't
25217         // generate the address generation required.
25218         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
25219         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
25220         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
25221                                       LLD->getBasePtr().getValueType()))
25222       return false;
25223 
25224     // The loads must not depend on one another.
25225     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
25226       return false;
25227 
25228     // Check that the select condition doesn't reach either load.  If so,
25229     // folding this will induce a cycle into the DAG.  If not, this is safe to
25230     // xform, so create a select of the addresses.
25231 
25232     SmallPtrSet<const SDNode *, 32> Visited;
25233     SmallVector<const SDNode *, 16> Worklist;
25234 
25235     // Always fail if LLD and RLD are not independent. TheSelect is a
25236     // predecessor to all Nodes in question so we need not search past it.
25237 
25238     Visited.insert(TheSelect);
25239     Worklist.push_back(LLD);
25240     Worklist.push_back(RLD);
25241 
25242     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
25243         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
25244       return false;
25245 
25246     SDValue Addr;
25247     if (TheSelect->getOpcode() == ISD::SELECT) {
25248       // We cannot do this optimization if any pair of {RLD, LLD} is a
25249       // predecessor to {RLD, LLD, CondNode}. As we've already compared the
25250       // Loads, we only need to check if CondNode is a successor to one of the
25251       // loads. We can further avoid this if there's no use of their chain
25252       // value.
25253       SDNode *CondNode = TheSelect->getOperand(0).getNode();
25254       Worklist.push_back(CondNode);
25255 
25256       if ((LLD->hasAnyUseOfValue(1) &&
25257            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
25258           (RLD->hasAnyUseOfValue(1) &&
25259            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
25260         return false;
25261 
25262       Addr = DAG.getSelect(SDLoc(TheSelect),
25263                            LLD->getBasePtr().getValueType(),
25264                            TheSelect->getOperand(0), LLD->getBasePtr(),
25265                            RLD->getBasePtr());
25266     } else {  // Otherwise SELECT_CC
25267       // We cannot do this optimization if any pair of {RLD, LLD} is a
25268       // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
25269       // the Loads, we only need to check if CondLHS/CondRHS is a successor to
25270       // one of the loads. We can further avoid this if there's no use of their
25271       // chain value.
25272 
25273       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
25274       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
25275       Worklist.push_back(CondLHS);
25276       Worklist.push_back(CondRHS);
25277 
25278       if ((LLD->hasAnyUseOfValue(1) &&
25279            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
25280           (RLD->hasAnyUseOfValue(1) &&
25281            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
25282         return false;
25283 
25284       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
25285                          LLD->getBasePtr().getValueType(),
25286                          TheSelect->getOperand(0),
25287                          TheSelect->getOperand(1),
25288                          LLD->getBasePtr(), RLD->getBasePtr(),
25289                          TheSelect->getOperand(4));
25290     }
25291 
25292     SDValue Load;
25293     // It is safe to replace the two loads if they have different alignments,
25294     // but the new load must use the minimum (most restrictive) alignment of the
25295     // inputs.
25296     Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
25297     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
25298     if (!RLD->isInvariant())
25299       MMOFlags &= ~MachineMemOperand::MOInvariant;
25300     if (!RLD->isDereferenceable())
25301       MMOFlags &= ~MachineMemOperand::MODereferenceable;
25302     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
25303       // FIXME: Discards pointer and AA info.
25304       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
25305                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
25306                          MMOFlags);
25307     } else {
25308       // FIXME: Discards pointer and AA info.
25309       Load = DAG.getExtLoad(
25310           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
25311                                                   : LLD->getExtensionType(),
25312           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
25313           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
25314     }
25315 
25316     // Users of the select now use the result of the load.
25317     CombineTo(TheSelect, Load);
25318 
25319     // Users of the old loads now use the new load's chain.  We know the
25320     // old-load value is dead now.
25321     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
25322     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
25323     return true;
25324   }
25325 
25326   return false;
25327 }
25328 
25329 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
25330 /// bitwise 'and'.
25331 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
25332                                             SDValue N1, SDValue N2, SDValue N3,
25333                                             ISD::CondCode CC) {
25334   // If this is a select where the false operand is zero and the compare is a
25335   // check of the sign bit, see if we can perform the "gzip trick":
25336   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
25337   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
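        // An illustrative i32 instance of the single-bit case handled below:
        //   select_cc setlt X, 0, 8, 0 -> and (srl X, 28), 8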
25338   EVT XType = N0.getValueType();
25339   EVT AType = N2.getValueType();
25340   if (!isNullConstant(N3) || !XType.bitsGE(AType))
25341     return SDValue();
25342 
25343   // If the comparison is testing for a positive value, we have to invert
25344   // the sign bit mask, so only do that transform if the target has a bitwise
25345   // 'and not' instruction (the invert is free).
25346   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
25347     // (X > -1) ? A : 0
25348     // (X >  0) ? X : 0 <-- This is canonical signed max.
25349     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
25350       return SDValue();
25351   } else if (CC == ISD::SETLT) {
25352     // (X <  0) ? A : 0
25353     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
25354     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
25355       return SDValue();
25356   } else {
25357     return SDValue();
25358   }
25359 
25360   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
25361   // constant.
25362   EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
25363   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
25364   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
25365     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
25366     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
25367       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
25368       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
25369       AddToWorklist(Shift.getNode());
25370 
25371       if (XType.bitsGT(AType)) {
25372         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
25373         AddToWorklist(Shift.getNode());
25374       }
25375 
25376       if (CC == ISD::SETGT)
25377         Shift = DAG.getNOT(DL, Shift, AType);
25378 
25379       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
25380     }
25381   }
25382 
25383   unsigned ShCt = XType.getSizeInBits() - 1;
25384   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
25385     return SDValue();
25386 
25387   SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
25388   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
25389   AddToWorklist(Shift.getNode());
25390 
25391   if (XType.bitsGT(AType)) {
25392     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
25393     AddToWorklist(Shift.getNode());
25394   }
25395 
25396   if (CC == ISD::SETGT)
25397     Shift = DAG.getNOT(DL, Shift, AType);
25398 
25399   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
25400 }
25401 
25402 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
25403 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
25404   SDValue N0 = N->getOperand(0);
25405   SDValue N1 = N->getOperand(1);
25406   SDValue N2 = N->getOperand(2);
25407   EVT VT = N->getValueType(0);
25408   SDLoc DL(N);
25409 
25410   unsigned BinOpc = N1.getOpcode();
25411   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc))
25412     return SDValue();
25413 
25414   // The use checks are intentionally on SDNode because we may be dealing
25415   // with opcodes that produce more than one SDValue.
25416   // TODO: Do we really need to check N0 (the condition operand of the select)?
25417   //       But removing that clause could cause an infinite loop...
25418   if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
25419     return SDValue();
25420 
25421   // Binops may include opcodes that return multiple values, so all values
25422   // must be created/propagated from the newly created binops below.
25423   SDVTList OpVTs = N1->getVTList();
25424 
25425   // Fold select(cond, binop(x, y), binop(z, y))
25426   //  --> binop(select(cond, x, z), y)
25427   if (N1.getOperand(1) == N2.getOperand(1)) {
25428     SDValue NewSel =
25429         DAG.getSelect(DL, VT, N0, N1.getOperand(0), N2.getOperand(0));
25430     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
25431     NewBinOp->setFlags(N1->getFlags());
25432     NewBinOp->intersectFlagsWith(N2->getFlags());
25433     return NewBinOp;
25434   }
25435 
25436   // Fold select(cond, binop(x, y), binop(x, z))
25437   //  --> binop(x, select(cond, y, z))
25438   // Second op VT might be different (e.g. shift amount type)
25439   if (N1.getOperand(0) == N2.getOperand(0) &&
25440       VT == N1.getOperand(1).getValueType() &&
25441       VT == N2.getOperand(1).getValueType()) {
25442     SDValue NewSel =
25443         DAG.getSelect(DL, VT, N0, N1.getOperand(1), N2.getOperand(1));
25444     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
25445     NewBinOp->setFlags(N1->getFlags());
25446     NewBinOp->intersectFlagsWith(N2->getFlags());
25447     return NewBinOp;
25448   }
25449 
25450   // TODO: Handle isCommutativeBinOp patterns as well?
25451   return SDValue();
25452 }
25453 
25454 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
25455 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
25456   SDValue N0 = N->getOperand(0);
25457   EVT VT = N->getValueType(0);
25458   bool IsFabs = N->getOpcode() == ISD::FABS;
25459   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
25460 
25461   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
25462     return SDValue();
25463 
25464   SDValue Int = N0.getOperand(0);
25465   EVT IntVT = Int.getValueType();
25466 
25467   // The operand of the cast should be a scalar integer.
25468   if (!IntVT.isInteger() || IntVT.isVector())
25469     return SDValue();
25470 
25471   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
25472   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
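        // e.g. for f32: sign = 0x80000000, so this becomes xor x, 0x80000000 for
        // fneg and and x, 0x7FFFFFFF for fabs.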
25473   APInt SignMask;
25474   if (N0.getValueType().isVector()) {
25475     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
25476     // 0x7f...) per element and splat it.
25477     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
25478     if (IsFabs)
25479       SignMask = ~SignMask;
25480     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
25481   } else {
25482     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
25483     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
25484     if (IsFabs)
25485       SignMask = ~SignMask;
25486   }
25487   SDLoc DL(N0);
25488   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
25489                     DAG.getConstant(SignMask, DL, IntVT));
25490   AddToWorklist(Int.getNode());
25491   return DAG.getBitcast(VT, Int);
25492 }
25493 
25494 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
25495 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
25496 /// in it. This may be a win when the constant is not otherwise available
25497 /// because it replaces two constant pool loads with one.
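      /// Illustrative shape of the emitted DAG (the code below stores the false
      /// value at offset 0 and the true value at offset 4 for 4-byte floats):
      ///   CP = constant pool {2.0f, 1.0f}
      ///   off = select (setcc a, b, cond), 4, 0
      ///   result = load (CP + off)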
25498 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
25499     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
25500     ISD::CondCode CC) {
25501   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
25502     return SDValue();
25503 
25504   // If we are before legalize types, we want the other legalization to happen
25505   // first (for example, to avoid messing with soft float).
25506   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
25507   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
25508   EVT VT = N2.getValueType();
25509   if (!TV || !FV || !TLI.isTypeLegal(VT))
25510     return SDValue();
25511 
25512   // If a constant can be materialized without loads, this does not make sense.
25513   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
25514       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
25515       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
25516     return SDValue();
25517 
25518   // If both constants have multiple uses, then we won't need to do an extra
25519   // load. The values are likely around in registers for other users.
25520   if (!TV->hasOneUse() && !FV->hasOneUse())
25521     return SDValue();
25522 
25523   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
25524                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
25525   Type *FPTy = Elts[0]->getType();
25526   const DataLayout &TD = DAG.getDataLayout();
25527 
25528   // Create a ConstantArray of the two constants.
25529   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
25530   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
25531                                       TD.getPrefTypeAlign(FPTy));
25532   Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
25533 
25534   // Get offsets to the 0 and 1 elements of the array, so we can select between
25535   // them.
25536   SDValue Zero = DAG.getIntPtrConstant(0, DL);
25537   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
25538   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
25539   SDValue Cond =
25540       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
25541   AddToWorklist(Cond.getNode());
25542   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
25543   AddToWorklist(CstOffset.getNode());
25544   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
25545   AddToWorklist(CPIdx.getNode());
25546   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
25547                      MachinePointerInfo::getConstantPool(
25548                          DAG.getMachineFunction()), Alignment);
25549 }
25550 
25551 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
25552 /// where 'cond' is the comparison specified by CC.
25553 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
25554                                       SDValue N2, SDValue N3, ISD::CondCode CC,
25555                                       bool NotExtCompare) {
25556   // (x ? y : y) -> y.
25557   if (N2 == N3) return N2;
25558 
25559   EVT CmpOpVT = N0.getValueType();
25560   EVT CmpResVT = getSetCCResultType(CmpOpVT);
25561   EVT VT = N2.getValueType();
25562   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
25563   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
25564   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
25565 
25566   // Determine if the condition we're dealing with is constant.
25567   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
25568     AddToWorklist(SCC.getNode());
25569     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
25570       // fold select_cc true, x, y -> x
25571       // fold select_cc false, x, y -> y
25572       return !(SCCC->isZero()) ? N2 : N3;
25573     }
25574   }
25575 
25576   if (SDValue V =
25577           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
25578     return V;
25579 
25580   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
25581     return V;
25582 
25583   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
25584   // where y has a single bit set.
25585   // A plaintext description would be: we can turn the SELECT_CC into an AND
25586   // when the condition can be materialized as an all-ones register.  Any
25587   // single bit-test can be materialized as an all-ones register with
25588   // shift-left and shift-right-arith.
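        // An illustrative i32 instance (bit 2 is tested, so shl by 29, sra by 31):
        //   select_cc seteq (and x, 4), 0, 0, A -> and (sra (shl x, 29), 31), A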
25589   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
25590       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
25591     SDValue AndLHS = N0->getOperand(0);
25592     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
25593     if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
25594       // Shift the tested bit over the sign bit.
25595       const APInt &AndMask = ConstAndRHS->getAPIntValue();
25596       unsigned ShCt = AndMask.getBitWidth() - 1;
25597       if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
25598         SDValue ShlAmt =
25599           DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
25600                           getShiftAmountTy(AndLHS.getValueType()));
25601         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
25602 
25603         // Now arithmetic right shift it all the way over, so the result is
25604         // either all-ones, or zero.
25605         SDValue ShrAmt =
25606           DAG.getConstant(ShCt, SDLoc(Shl),
25607                           getShiftAmountTy(Shl.getValueType()));
25608         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
25609 
25610         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
25611       }
25612     }
25613   }
25614 
25615   // fold select C, 16, 0 -> shl C, 4
25616   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
25617   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
25618 
25619   if ((Fold || Swap) &&
25620       TLI.getBooleanContents(CmpOpVT) ==
25621           TargetLowering::ZeroOrOneBooleanContent &&
25622       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
25623 
25624     if (Swap) {
25625       CC = ISD::getSetCCInverse(CC, CmpOpVT);
25626       std::swap(N2C, N3C);
25627     }
25628 
25629     // If the caller doesn't want us to simplify this into a zext of a compare,
25630     // don't do it.
25631     if (NotExtCompare && N2C->isOne())
25632       return SDValue();
25633 
25634     SDValue Temp, SCC;
25635     // zext (setcc n0, n1)
25636     if (LegalTypes) {
25637       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
25638       if (VT.bitsLT(SCC.getValueType()))
25639         Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
25640       else
25641         Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
25642     } else {
25643       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
25644       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
25645     }
25646 
25647     AddToWorklist(SCC.getNode());
25648     AddToWorklist(Temp.getNode());
25649 
25650     if (N2C->isOne())
25651       return Temp;
25652 
25653     unsigned ShCt = N2C->getAPIntValue().logBase2();
25654     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
25655       return SDValue();
25656 
25657     // shl setcc result by log2 n2c
25658     return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
25659                        DAG.getConstant(ShCt, SDLoc(Temp),
25660                                        getShiftAmountTy(Temp.getValueType())));
25661   }
25662 
25663   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
25664   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
25665   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
25666   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
25667   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
25668   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
25669   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
25670   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
25671   if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
25672     SDValue ValueOnZero = N2;
25673     SDValue Count = N3;
25674     // If the condition is NE instead of EQ, swap the operands.
25675     if (CC == ISD::SETNE)
25676       std::swap(ValueOnZero, Count);
25677     // Check if the value on zero is a constant equal to the bits in the type.
25678     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
25679       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
25680         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
25681         // legal, combine to just cttz.
25682         if ((Count.getOpcode() == ISD::CTTZ ||
25683              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
25684             N0 == Count.getOperand(0) &&
25685             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
25686           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
25687         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
25688         // legal, combine to just ctlz.
25689         if ((Count.getOpcode() == ISD::CTLZ ||
25690              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
25691             N0 == Count.getOperand(0) &&
25692             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
25693           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
25694       }
25695     }
25696   }
25697 
25698   // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
25699   // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
25700   if (!NotExtCompare && N1C && N2C && N3C &&
25701       N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
25702       ((N1C->isAllOnes() && CC == ISD::SETGT) ||
25703        (N1C->isZero() && CC == ISD::SETLT)) &&
25704       !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
25705     SDValue ASR = DAG.getNode(
25706         ISD::SRA, DL, CmpOpVT, N0,
25707         DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
25708     return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
25709                        DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
25710   }
25711 
25712   if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
25713     return S;
25714   if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
25715     return S;
25716 
25717   return SDValue();
25718 }
25719 
25720 /// This is a stub for TargetLowering::SimplifySetCC.
25721 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
25722                                    ISD::CondCode Cond, const SDLoc &DL,
25723                                    bool foldBooleans) {
25724   TargetLowering::DAGCombinerInfo
25725     DagCombineInfo(DAG, Level, false, this);
25726   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
25727 }
25728 
25729 /// Given an ISD::SDIV node expressing a divide by constant, return
25730 /// a DAG expression that will generate the same value by multiplying
25731 /// by a magic number.
25732 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
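      /// An illustrative i32 instance (magic constant per Hacker's Delight; the
      /// exact sequence is target-dependent):
      ///   sdiv x, 3 -> add (mulhs x, 0x55555556), (srl x, 31)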
25733 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
25734   // When optimising for minimum size, we don't want to expand a div to a mul
25735   // and a shift.
25736   if (DAG.getMachineFunction().getFunction().hasMinSize())
25737     return SDValue();
25738 
25739   SmallVector<SDNode *, 8> Built;
25740   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
25741     for (SDNode *N : Built)
25742       AddToWorklist(N);
25743     return S;
25744   }
25745 
25746   return SDValue();
25747 }
25748 
25749 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
25750 /// DAG expression that will generate the same value by right shifting.
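      /// An illustrative i32 instance (negative values are biased before the
      /// shift to get round-toward-zero semantics):
      ///   sdiv x, 8 -> sra (add x, (srl (sra x, 31), 29)), 3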
25751 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
25752   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
25753   if (!C)
25754     return SDValue();
25755 
25756   // Avoid division by zero.
25757   if (C->isZero())
25758     return SDValue();
25759 
25760   SmallVector<SDNode *, 8> Built;
25761   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
25762     for (SDNode *N : Built)
25763       AddToWorklist(N);
25764     return S;
25765   }
25766 
25767   return SDValue();
25768 }
25769 
25770 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
25771 /// expression that will generate the same value by multiplying by a magic
25772 /// number.
25773 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
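      /// An illustrative i32 instance (magic constant per Hacker's Delight):
      ///   udiv x, 3 -> srl (mulhu x, 0xAAAAAAAB), 1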
25774 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
25775   // When optimising for minimum size, we don't want to expand a div to a mul
25776   // and a shift.
25777   if (DAG.getMachineFunction().getFunction().hasMinSize())
25778     return SDValue();
25779 
25780   SmallVector<SDNode *, 8> Built;
25781   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
25782     for (SDNode *N : Built)
25783       AddToWorklist(N);
25784     return S;
25785   }
25786 
25787   return SDValue();
25788 }
25789 
25790 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
25791 /// return a DAG expression that will generate the same value.
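      /// An illustrative i32 instance (bias, mask, then subtract the rounded part):
      ///   srem x, 8 -> sub x, (and (add x, (srl (sra x, 31), 29)), -8)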
25792 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
25793   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
25794   if (!C)
25795     return SDValue();
25796 
25797   // Avoid division by zero.
25798   if (C->isZero())
25799     return SDValue();
25800 
25801   SmallVector<SDNode *, 8> Built;
25802   if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
25803     for (SDNode *N : Built)
25804       AddToWorklist(N);
25805     return S;
25806   }
25807 
25808   return SDValue();
25809 }
25810 
25811 /// Determines the LogBase2 value for a non-null input value using the
25812 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
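      /// e.g. for i32 V = 16: ctlz(16) = 27, so LogBase2 = (32 - 1) - 27 = 4.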
25813 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
25814   EVT VT = V.getValueType();
25815   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
25816   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
25817   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
25818   return LogBase2;
25819 }
25820 
25821 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25822 /// For the reciprocal, we need to find the zero of the function:
25823 ///   F(X) = 1/X - A [which has a zero at X = 1/A]
25824 ///     =>
25825 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
25826 ///     does not require additional intermediate precision]
25827 /// For the last iteration, put numerator N into it to gain more precision:
25828 ///   Result = N X_i + X_i (N - N A X_i)
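      /// With a single refinement step, this matches the sequence emitted below:
      ///   MulEst = N * Est; Result = MulEst + Est * (N - A * MulEst)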
25829 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
25830                                       SDNodeFlags Flags) {
25831   if (LegalDAG)
25832     return SDValue();
25833 
25834   // TODO: Handle extended types?
25835   EVT VT = Op.getValueType();
25836   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
25837       VT.getScalarType() != MVT::f64)
25838     return SDValue();
25839 
25840   // If estimates are explicitly disabled for this function, we're done.
25841   MachineFunction &MF = DAG.getMachineFunction();
25842   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
25843   if (Enabled == TLI.ReciprocalEstimate::Disabled)
25844     return SDValue();
25845 
25846   // Estimates may be explicitly enabled for this type with a custom number of
25847   // refinement steps.
25848   int Iterations = TLI.getDivRefinementSteps(VT, MF);
25849   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
25850     AddToWorklist(Est.getNode());
25851 
25852     SDLoc DL(Op);
25853     if (Iterations) {
25854       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
25855 
25856       // Newton iterations: Est = Est + Est (N - Arg * Est)
25857       // If this is the last iteration, also multiply by the numerator.
25858       for (int i = 0; i < Iterations; ++i) {
25859         SDValue MulEst = Est;
25860 
25861         if (i == Iterations - 1) {
25862           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
25863           AddToWorklist(MulEst.getNode());
25864         }
25865 
25866         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
25867         AddToWorklist(NewEst.getNode());
25868 
25869         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
25870                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
25871         AddToWorklist(NewEst.getNode());
25872 
25873         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
25874         AddToWorklist(NewEst.getNode());
25875 
25876         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
25877         AddToWorklist(Est.getNode());
25878       }
25879     } else {
25880       // If no iterations are available, multiply with N.
25881       // If no iterations are available, multiply by N.
25882       AddToWorklist(Est.getNode());
25883     }
25884 
25885     return Est;
25886   }
25887 
25888   return SDValue();
25889 }
25890 
25891 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25892 /// For the reciprocal sqrt, we need to find the zero of the function:
25893 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
25894 ///     =>
25895 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
25896 /// As a result, we precompute A/2 prior to the iteration loop.
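      /// e.g. for A = 4 and Est = 0.51, one step gives
      /// 0.51 * (1.5 - 2.0 * 0.51 * 0.51) ~= 0.4997, approaching 1/sqrt(4) = 0.5.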
25897 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
25898                                          unsigned Iterations,
25899                                          SDNodeFlags Flags, bool Reciprocal) {
25900   EVT VT = Arg.getValueType();
25901   SDLoc DL(Arg);
25902   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
25903 
25904   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
25905   // this entire sequence requires only one FP constant.
25906   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
25907   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
25908 
25909   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
25910   for (unsigned i = 0; i < Iterations; ++i) {
25911     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
25912     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
25913     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
25914     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
25915   }
25916 
25917   // If non-reciprocal square root is requested, multiply the result by Arg.
25918   if (!Reciprocal)
25919     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
25920 
25921   return Est;
25922 }
25923 
25924 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25925 /// For the reciprocal sqrt, we need to find the zero of the function:
25926 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
25927 ///     =>
25928 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
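      /// This is algebraically the same update as X_i * (1.5 - 0.5 * A * X_i^2),
      /// refactored so (A * X_i) * X_i is a common subexpression (see loop below).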
25929 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
25930                                          unsigned Iterations,
25931                                          SDNodeFlags Flags, bool Reciprocal) {
25932   EVT VT = Arg.getValueType();
25933   SDLoc DL(Arg);
25934   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
25935   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
25936 
25937   // This routine must enter the loop below to work correctly
25938   // when (Reciprocal == false).
25939   assert(Iterations > 0);
25940 
25941   // Newton iterations for reciprocal square root:
25942   // E = (E * -0.5) * ((A * E) * E + -3.0)
25943   for (unsigned i = 0; i < Iterations; ++i) {
25944     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
25945     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
25946     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
25947 
25948     // When calculating a square root at the last iteration build:
25949     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
25950     // (notice a common subexpression)
25951     SDValue LHS;
25952     if (Reciprocal || (i + 1) < Iterations) {
25953       // RSQRT: LHS = (E * -0.5)
25954       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
25955     } else {
25956       // SQRT: LHS = (A * E) * -0.5
25957       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
25958     }
25959 
25960     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
25961   }
25962 
25963   return Est;
25964 }
25965 
25966 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
25967 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
25968 /// Op can be zero.
25969 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
25970                                            bool Reciprocal) {
25971   if (LegalDAG)
25972     return SDValue();
25973 
25974   // TODO: Handle extended types?
25975   EVT VT = Op.getValueType();
25976   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
25977       VT.getScalarType() != MVT::f64)
25978     return SDValue();
25979 
25980   // If estimates are explicitly disabled for this function, we're done.
25981   MachineFunction &MF = DAG.getMachineFunction();
25982   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
25983   if (Enabled == TLI.ReciprocalEstimate::Disabled)
25984     return SDValue();
25985 
25986   // Estimates may be explicitly enabled for this type with a custom number of
25987   // refinement steps.
25988   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
25989 
25990   bool UseOneConstNR = false;
25991   if (SDValue Est =
25992       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
25993                           Reciprocal)) {
25994     AddToWorklist(Est.getNode());
25995 
25996     if (Iterations)
25997       Est = UseOneConstNR
25998             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
25999             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
26000     if (!Reciprocal) {
26001       SDLoc DL(Op);
26002       // Try the target specific test first.
26003       SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
26004 
26005       // The estimate is now completely wrong if the input was exactly 0.0 or
26006       // possibly a denormal. Force the answer to 0.0 or the value provided by
26007       // the target for those cases.
26008       Est = DAG.getNode(
26009           Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
26010           Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
26011     }
26012     return Est;
26013   }
26014 
26015   return SDValue();
26016 }
26017 
26018 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
26019   return buildSqrtEstimateImpl(Op, Flags, true);
26020 }
26021 
26022 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
26023   return buildSqrtEstimateImpl(Op, Flags, false);
26024 }
26025 
26026 /// Return true if there is any possibility that the two addresses overlap.
26027 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
26028 
26029   struct MemUseCharacteristics {
26030     bool IsVolatile;
26031     bool IsAtomic;
26032     SDValue BasePtr;
26033     int64_t Offset;
26034     std::optional<int64_t> NumBytes;
26035     MachineMemOperand *MMO;
26036   };
26037 
26038   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
26039     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
26040       int64_t Offset = 0;
26041       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
26042         Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
26043                      ? C->getSExtValue()
26044                      : (LSN->getAddressingMode() == ISD::PRE_DEC)
26045                            ? -1 * C->getSExtValue()
26046                            : 0;
26047       uint64_t Size =
26048           MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
26049       return {LSN->isVolatile(),
26050               LSN->isAtomic(),
26051               LSN->getBasePtr(),
26052               Offset /*base offset*/,
26053               std::optional<int64_t>(Size),
26054               LSN->getMemOperand()};
26055     }
26056     if (const auto *LN = cast<LifetimeSDNode>(N))
26057       return {false /*isVolatile*/,
26058               /*isAtomic*/ false,
26059               LN->getOperand(1),
26060               (LN->hasOffset()) ? LN->getOffset() : 0,
26061               (LN->hasOffset()) ? std::optional<int64_t>(LN->getSize())
26062                                 : std::optional<int64_t>(),
26063               (MachineMemOperand *)nullptr};
26064     // Default.
26065     return {false /*isvolatile*/,
26066             /*isAtomic*/ false,          SDValue(),
26067             (int64_t)0 /*offset*/,       std::optional<int64_t>() /*size*/,
26068             (MachineMemOperand *)nullptr};
26069   };
26070 
26071   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
26072                         MUC1 = getCharacteristics(Op1);
26073 
26074   // If they access the same address, then they must be aliases.
26075   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
26076       MUC0.Offset == MUC1.Offset)
26077     return true;
26078 
26079   // If they are both volatile then they cannot be reordered.
26080   if (MUC0.IsVolatile && MUC1.IsVolatile)
26081     return true;
26082 
26083   // Be conservative about atomics for the moment
26084   // TODO: This is way overconservative for unordered atomics (see D66309)
26085   if (MUC0.IsAtomic && MUC1.IsAtomic)
26086     return true;
26087 
26088   if (MUC0.MMO && MUC1.MMO) {
26089     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
26090         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
26091       return false;
26092   }
26093 
26094   // Try to prove that there is aliasing, or that there is no aliasing. Either
26095   // way, we can return now. If nothing can be proved, proceed with more tests.
26096   bool IsAlias;
26097   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
26098                                        DAG, IsAlias))
26099     return IsAlias;
26100 
26101   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
26102   // either are not known.
26103   if (!MUC0.MMO || !MUC1.MMO)
26104     return true;
26105 
26106   // If one operation reads from invariant memory, and the other may store, they
26107   // cannot alias. These should really be checking the equivalent of mayWrite,
26108   // but it only matters for memory nodes other than load/store.
26109   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
26110       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
26111     return false;
26112 
26113   // If we know SrcValue1 and SrcValue2 have relatively large alignment
26114   // compared to the size and offset of the access, we may be able to prove
26115   // they do not alias. This check is conservative for now to catch cases
26116   // created by splitting vector types; it only works when the offsets are
26117   // multiples of the size of the data.
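  // For example, two 4-byte accesses sharing an 8-byte base alignment at
  // source offsets 0 and 4 land in disjoint slots of the alignment window
  // (OffAlign 0 and 4), so OffAlign0 + *Size0 <= OffAlign1 proves no overlap.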
26118   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
26119   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
26120   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
26121   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
26122   auto &Size0 = MUC0.NumBytes;
26123   auto &Size1 = MUC1.NumBytes;
26124   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
26125       Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
26126       OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
26127       SrcValOffset1 % *Size1 == 0) {
26128     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
26129     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
26130 
26131     // There is no overlap between these relatively aligned accesses of
26132     // similar size. Return no alias.
26133     if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
26134       return false;
26135   }
26136 
26137   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
26138                    ? CombinerGlobalAA
26139                    : DAG.getSubtarget().useAA();
26140 #ifndef NDEBUG
26141   if (CombinerAAOnlyFunc.getNumOccurrences() &&
26142       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
26143     UseAA = false;
26144 #endif
26145 
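  // When alias analysis is usable, widen each location so both span from the
  // smaller source offset to their own end, letting AA compare the accesses
  // over a common starting point; e.g. with offsets 8 and 12 and 4-byte
  // sizes, MinOffset is 8 and the queried sizes become 4 and 8 bytes.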
26146   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
26147       Size1) {
26148     // Use alias analysis information.
26149     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
26150     int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
26151     int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
26152     if (AA->isNoAlias(
26153             MemoryLocation(MUC0.MMO->getValue(), Overlap0,
26154                            UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
26155             MemoryLocation(MUC1.MMO->getValue(), Overlap1,
26156                            UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
26157       return false;
26158   }
26159 
26160   // Otherwise we have to assume they alias.
26161   return true;
26162 }
26163 
26164 /// Walk up chain skipping non-aliasing memory nodes,
26165 /// looking for aliasing nodes and adding them to the Aliases vector.
26166 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
26167                                    SmallVectorImpl<SDValue> &Aliases) {
26168   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
26169   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
26170 
26171   // Get alias information for node.
26172   // TODO: relax aliasing for unordered atomics (see D66309)
26173   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
26174 
26175   // Starting off.
26176   Chains.push_back(OriginalChain);
26177   unsigned Depth = 0;
26178 
26179   // Attempt to improve chain by a single step
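  // On success, C is advanced one step up the chain (or cleared to a null
  // SDValue at EntryToken to terminate that path); returning false means C
  // must be treated as a potential alias and the walk stops there.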
26180   auto ImproveChain = [&](SDValue &C) -> bool {
26181     switch (C.getOpcode()) {
26182     case ISD::EntryToken:
26183       // No need to mark EntryToken.
26184       C = SDValue();
26185       return true;
26186     case ISD::LOAD:
26187     case ISD::STORE: {
26188       // Get alias information for C.
26189       // TODO: Relax aliasing for unordered atomics (see D66309)
26190       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
26191                       cast<LSBaseSDNode>(C.getNode())->isSimple();
26192       if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
26193         // Look further up the chain.
26194         C = C.getOperand(0);
26195         return true;
26196       }
26197       // Alias, so stop here.
26198       return false;
26199     }
26200 
26201     case ISD::CopyFromReg:
26202       // Always forward past CopyFromReg.
26203       C = C.getOperand(0);
26204       return true;
26205 
26206     case ISD::LIFETIME_START:
26207     case ISD::LIFETIME_END: {
26208       // We can forward past any lifetime start/end that can be proven not to
26209       // alias the memory access.
26210       if (!mayAlias(N, C.getNode())) {
26211         // Look further up the chain.
26212         C = C.getOperand(0);
26213         return true;
26214       }
26215       return false;
26216     }
26217     default:
26218       return false;
26219     }
26220   };
26221 
26222   // Look at each chain and determine if it is an alias.  If so, add it to the
26223   // aliases list.  If not, then continue up the chain looking for the next
26224   // candidate.
26225   while (!Chains.empty()) {
26226     SDValue Chain = Chains.pop_back_val();
26227 
26228     // Don't bother if we've seen Chain before.
26229     if (!Visited.insert(Chain.getNode()).second)
26230       continue;
26231 
26232     // For TokenFactor nodes, look at each operand and only continue up the
26233     // chain until we reach the depth limit.
26234     //
26235     // FIXME: The depth check could be made to return the last non-aliasing
26236     // chain we found before we hit a tokenfactor rather than the original
26237     // chain.
26238     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
26239       Aliases.clear();
26240       Aliases.push_back(OriginalChain);
26241       return;
26242     }
26243 
26244     if (Chain.getOpcode() == ISD::TokenFactor) {
26245       // We have to check each of the operands of the token factor for "small"
26246       // token factors, so we queue them up.  Adding the operands to the queue
26247       // (stack) in reverse order maintains the original order and increases the
26248       // likelihood that getNode will find a matching token factor (CSE).
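      // (Operands are pushed last-to-first, so operand 0 is popped and
      // visited first, preserving the original operand order.)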
26249       if (Chain.getNumOperands() > 16) {
26250         Aliases.push_back(Chain);
26251         continue;
26252       }
26253       for (unsigned n = Chain.getNumOperands(); n;)
26254         Chains.push_back(Chain.getOperand(--n));
26255       ++Depth;
26256       continue;
26257     }
26258     // Everything else
26259     if (ImproveChain(Chain)) {
26260       // Updated chain found; consider the new chain if one exists.
26261       if (Chain.getNode())
26262         Chains.push_back(Chain);
26263       ++Depth;
26264       continue;
26265     }
26266     // No improved chain possible; treat it as an alias.
26267     Aliases.push_back(Chain);
26268   }
26269 }
26270 
26271 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
26272 /// (aliasing node.)
26273 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
26274   if (OptLevel == CodeGenOpt::None)
26275     return OldChain;
26276 
26277   // Ops for replacing token factor.
26278   SmallVector<SDValue, 8> Aliases;
26279 
26280   // Accumulate all the aliases to this node.
26281   GatherAllAliases(N, OldChain, Aliases);
26282 
26283   // If no operands then chain to entry token.
26284   if (Aliases.size() == 0)
26285     return DAG.getEntryNode();
26286 
26287   // If a single operand then chain to it.  We don't need to revisit it.
26288   if (Aliases.size() == 1)
26289     return Aliases[0];
26290 
26291   // Construct a custom tailored token factor.
26292   return DAG.getTokenFactor(SDLoc(N), Aliases);
26293 }
26294 
26295 // This function tries to collect a bunch of potentially interesting
26296 // nodes to improve the chains of, all at once. This might seem
26297 // redundant, as this function gets called when visiting every store
26298 // node, so why not let the work be done on each store as it's visited?
26299 //
26300 // I believe this is mainly important because mergeConsecutiveStores
26301 // is unable to deal with merging stores of different sizes, so unless
26302 // we improve the chains of all the potential candidates up-front
26303 // before running mergeConsecutiveStores, it might only see some of
26304 // the nodes that will eventually be candidates, and then not be able
26305 // to go from a partially-merged state to the desired final
26306 // fully-merged state.
26307 
26308 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
26309   SmallVector<StoreSDNode *, 8> ChainedStores;
26310   StoreSDNode *STChain = St;
26311   // Intervals records which byte offsets from BaseIndex have been covered. In
26312   // the common case, each store writes to the addresses immediately below the
26313   // previous one and is thus merged with the previous interval at insertion time.
26314 
26315   using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
26316                                  IntervalMapHalfOpenInfo<int64_t>>;
26317   IMap::Allocator A;
26318   IMap Intervals(A);
26319 
26320   // This holds the base pointer, index, and the offset in bytes from the base
26321   // pointer.
26322   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
26323 
26324   // We must have a base and an offset.
26325   if (!BasePtr.getBase().getNode())
26326     return false;
26327 
26328   // Do not handle stores to undef base pointers.
26329   if (BasePtr.getBase().isUndef())
26330     return false;
26331 
26332   // Do not handle stores to opaque (zero-sized) types.
26333   if (St->getMemoryVT().isZeroSized())
26334     return false;
26335 
26336   // BaseIndexOffset assumes that offsets are fixed-size, which
26337   // is not valid for scalable vectors where the offsets are
26338   // scaled by `vscale`, so bail out early.
26339   if (St->getMemoryVT().isScalableVector())
26340     return false;
26341 
26342   // Add ST's interval.
26343   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
26344                    std::monostate{});
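  // E.g. an i32 store seeds the map with the half-open byte interval [0, 4);
  // every chained store found below must claim a disjoint interval relative
  // to the same base.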
26345 
26346   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
26347     if (Chain->getMemoryVT().isScalableVector())
26348       return false;
26349 
26350     // If the chain has more than one use, then we can't reorder the mem ops.
26351     if (!SDValue(Chain, 0)->hasOneUse())
26352       break;
26353     // TODO: Relax for unordered atomics (see D66309)
26354     if (!Chain->isSimple() || Chain->isIndexed())
26355       break;
26356 
26357     // Find the base pointer and offset for this memory node.
26358     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
26359     // Check that the base pointer is the same as the original one.
26360     int64_t Offset;
26361     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
26362       break;
26363     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
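    // In the common descending pattern, e.g. 4-byte stores at offsets 0, -4,
    // -8, ..., each new interval lands immediately to the left of those
    // already inserted, and the two overlap checks below pass trivially.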
26364     // Make sure we don't overlap with other intervals by checking the ones to
26365     // the left or right before inserting.
26366     auto I = Intervals.find(Offset);
26367     // If there's a next interval, we should end before it.
26368     if (I != Intervals.end() && I.start() < (Offset + Length))
26369       break;
26370     // If there's a previous interval, we should start after it.
26371     if (I != Intervals.begin() && (--I).stop() <= Offset)
26372       break;
26373     Intervals.insert(Offset, Offset + Length, std::monostate{});
26374 
26375     ChainedStores.push_back(Chain);
26376     STChain = Chain;
26377   }
26378 
26379   // If we didn't find a chained store, exit.
26380   if (ChainedStores.size() == 0)
26381     return false;
26382 
26383   // Improve all chained stores (St and ChainedStores members) starting from
26384   // where the store chain ended and return a single TokenFactor.
26385   SDValue NewChain = STChain->getChain();
26386   SmallVector<SDValue, 8> TFOps;
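  // ChainedStores was filled while walking up the chain, so its last element
  // is the store nearest NewChain; iterating from that end rewrites the
  // stores in original chain order, re-chaining each onto the best chain
  // FindBetterChain can prove safe so they become independent operands of
  // the TokenFactor built below.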
26387   for (unsigned I = ChainedStores.size(); I;) {
26388     StoreSDNode *S = ChainedStores[--I];
26389     SDValue BetterChain = FindBetterChain(S, NewChain);
26390     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
26391         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
26392     TFOps.push_back(SDValue(S, 0));
26393     ChainedStores[I] = S;
26394   }
26395 
26396   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
26397   SDValue BetterChain = FindBetterChain(St, NewChain);
26398   SDValue NewST;
26399   if (St->isTruncatingStore())
26400     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
26401                               St->getBasePtr(), St->getMemoryVT(),
26402                               St->getMemOperand());
26403   else
26404     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
26405                          St->getBasePtr(), St->getMemOperand());
26406 
26407   TFOps.push_back(NewST);
26408 
26409   // If we improved every element of TFOps, then we've lost the dependence on
26410   // NewChain to successors of St and we need to add it back to TFOps. Do so at
26411   // the beginning to keep relative order consistent with FindBetterChains.
26412   auto hasImprovedChain = [&](SDValue ST) -> bool {
26413     return ST->getOperand(0) != NewChain;
26414   };
26415   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
26416   if (AddNewChain)
26417     TFOps.insert(TFOps.begin(), NewChain);
26418 
26419   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
26420   CombineTo(St, TF);
26421 
26422   // Add TF and its operands to the worklist.
26423   AddToWorklist(TF.getNode());
26424   for (const SDValue &Op : TF->ops())
26425     AddToWorklist(Op.getNode());
26426   AddToWorklist(STChain);
26427   return true;
26428 }
26429 
26430 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
26431   if (OptLevel == CodeGenOpt::None)
26432     return false;
26433 
26434   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
26435 
26436   // We must have a base and an offset.
26437   if (!BasePtr.getBase().getNode())
26438     return false;
26439 
26440   // Do not handle stores to undef base pointers.
26441   if (BasePtr.getBase().isUndef())
26442     return false;
26443 
26444   // Directly improve a chain of disjoint stores starting at St.
26445   if (parallelizeChainedStores(St))
26446     return true;
26447 
26448   // Improve St's chain.
26449   SDValue BetterChain = FindBetterChain(St, St->getChain());
26450   if (St->getChain() != BetterChain) {
26451     replaceStoreChain(St, BetterChain);
26452     return true;
26453   }
26454   return false;
26455 }
26456 
26457 /// This is the entry point for the file.
26458 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
26459                            CodeGenOpt::Level OptLevel) {
26460   /// This is the main entry point to this class.
26461   DAGCombiner(*this, AA, OptLevel).Run(Level);
26462 }
26463