1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52
isTypeFoldingSupported(unsigned Opcode)53 bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56
SPIRVLegalizerInfo(const SPIRVSubtarget & ST)57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 = LLT::pointer(5, PSize); // Input
106
107 // TODO: remove copy-pasting here by using concatenation in some way.
108 auto allPtrsScalarsAndVectors = {
109 p0, p1, p2, p3, p4, p5, s1, s8, s16,
110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114 auto allScalarsAndVectors = {
115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126 auto allIntScalars = {s8, s16, s32, s64};
127
128 auto allFloatScalarsAndVectors = {
129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131
132 auto allFloatAndIntScalars = allIntScalars;
133
134 auto allPtrs = {p0, p1, p2, p3, p4, p5};
135 auto allWritablePtrs = {p0, p1, p3, p4};
136
137 for (auto Opc : TypeFoldingSupportingOpcs)
138 getActionDefinitionsBuilder(Opc).custom();
139
140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141
142 // TODO: add proper rules for vectors legalization.
143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144
145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147
148 getActionDefinitionsBuilder(G_MEMSET).legalIf(
149 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150
151 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152 .legalForCartesianProduct(allPtrs, allPtrs);
153
154 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155
156 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157
158 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159
160 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161 .legalForCartesianProduct(allIntScalarsAndVectors,
162 allFloatScalarsAndVectors);
163
164 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165 .legalForCartesianProduct(allFloatScalarsAndVectors,
166 allScalarsAndVectors);
167
168 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169 .legalFor(allIntScalarsAndVectors);
170
171 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172 allIntScalarsAndVectors, allIntScalarsAndVectors);
173
174 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175
176 getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177 typeInSet(0, allPtrsScalarsAndVectors),
178 typeInSet(1, allPtrsScalarsAndVectors),
179 LegalityPredicate(([=](const LegalityQuery &Query) {
180 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181 }))));
182
183 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184
185 getActionDefinitionsBuilder(G_INTTOPTR)
186 .legalForCartesianProduct(allPtrs, allIntScalars);
187 getActionDefinitionsBuilder(G_PTRTOINT)
188 .legalForCartesianProduct(allIntScalars, allPtrs);
189 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190 allPtrs, allIntScalars);
191
192 // ST.canDirectlyComparePointers() for pointer args is supported in
193 // legalizeCustom().
194 getActionDefinitionsBuilder(G_ICMP).customIf(
195 all(typeInSet(0, allBoolScalarsAndVectors),
196 typeInSet(1, allPtrsScalarsAndVectors)));
197
198 getActionDefinitionsBuilder(G_FCMP).legalIf(
199 all(typeInSet(0, allBoolScalarsAndVectors),
200 typeInSet(1, allFloatScalarsAndVectors)));
201
202 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207
208 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210
211 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212 // TODO: add proper legalization rules.
213 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214
215 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216 .alwaysLegal();
217
218 // Extensions.
219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220 .legalForCartesianProduct(allScalarsAndVectors);
221
222 // FP conversions.
223 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224 .legalForCartesianProduct(allFloatScalarsAndVectors);
225
226 // Pointer-handling.
227 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228
229 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231
232 getActionDefinitionsBuilder({G_FPOW,
233 G_FEXP,
234 G_FEXP2,
235 G_FLOG,
236 G_FLOG2,
237 G_FABS,
238 G_FMINNUM,
239 G_FMAXNUM,
240 G_FCEIL,
241 G_FCOS,
242 G_FSIN,
243 G_FSQRT,
244 G_FFLOOR,
245 G_FRINT,
246 G_FNEARBYINT,
247 G_INTRINSIC_ROUND,
248 G_INTRINSIC_TRUNC,
249 G_FMINIMUM,
250 G_FMAXIMUM,
251 G_INTRINSIC_ROUNDEVEN})
252 .legalFor(allFloatScalarsAndVectors);
253
254 getActionDefinitionsBuilder(G_FCOPYSIGN)
255 .legalForCartesianProduct(allFloatScalarsAndVectors,
256 allFloatScalarsAndVectors);
257
258 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
259 allFloatScalarsAndVectors, allIntScalarsAndVectors);
260
261 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
262 getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);
263
264 getActionDefinitionsBuilder(
265 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
266 .legalForCartesianProduct(allIntScalarsAndVectors,
267 allIntScalarsAndVectors);
268
269 // Struct return types become a single scalar, so cannot easily legalize.
270 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
271 }
272
273 getLegacyLegalizerInfo().computeTables();
274 verify(*ST.getInstrInfo());
275 }
276
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpirvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)277 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
278 LegalizerHelper &Helper,
279 MachineRegisterInfo &MRI,
280 SPIRVGlobalRegistry *GR) {
281 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
282 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
283 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
284 .addDef(ConvReg)
285 .addUse(Reg);
286 return ConvReg;
287 }
288
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI) const289 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
290 MachineInstr &MI) const {
291 auto Opc = MI.getOpcode();
292 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
293 if (!isTypeFoldingSupported(Opc)) {
294 assert(Opc == TargetOpcode::G_ICMP);
295 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
296 auto &Op0 = MI.getOperand(2);
297 auto &Op1 = MI.getOperand(3);
298 Register Reg0 = Op0.getReg();
299 Register Reg1 = Op1.getReg();
300 CmpInst::Predicate Cond =
301 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
302 if ((!ST->canDirectlyComparePointers() ||
303 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
304 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
305 LLT ConvT = LLT::scalar(ST->getPointerSize());
306 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
307 ST->getPointerSize());
308 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
309 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
310 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
311 }
312 return true;
313 }
314 // TODO: implement legalization for other opcodes.
315 return true;
316 }
317