// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "SpirvShader.hpp"
#include "SpirvShaderDebug.hpp"

#include "ShaderCore.hpp"

#include <spirv/unified1/spirv.hpp>

#include <limits>

namespace sw {

void SpirvEmitter::EmitVectorTimesScalar(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    for(auto i = 0u; i < type.componentCount; i++)
    {
        dst.move(i, lhs.Float(i) * rhs.Float(0));
    }
}

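// Matrices are flattened column by column: component (row, col) of a matrix
// with R rows is stored at index row + R * col. OpMatrixTimesVector computes
// result[i] = dot(row i of lhs, rhs).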
void SpirvEmitter::EmitMatrixTimesVector(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    for(auto i = 0u; i < type.componentCount; i++)
    {
        SIMD::Float v = lhs.Float(i) * rhs.Float(0);
        for(auto j = 1u; j < rhs.componentCount; j++)
        {
            v = MulAdd(lhs.Float(i + type.componentCount * j), rhs.Float(j), v);
        }
        dst.move(i, v);
    }
}

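// OpVectorTimesMatrix: result[i] = dot(lhs, column i of rhs).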
void SpirvEmitter::EmitVectorTimesMatrix(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    for(auto i = 0u; i < type.componentCount; i++)
    {
        SIMD::Float v = lhs.Float(0) * rhs.Float(i * lhs.componentCount);
        for(auto j = 1u; j < lhs.componentCount; j++)
        {
            v = MulAdd(lhs.Float(j), rhs.Float(i * lhs.componentCount + j), v);
        }
        dst.move(i, v);
    }
}

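// OpMatrixTimesMatrix: the result type provides the column and row counts,
// while numAdds (the column count of the left-hand matrix) is the shared inner
// dimension summed over for each result element.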
void SpirvEmitter::EmitMatrixTimesMatrix(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    auto numColumns = type.definition.word(3);
    auto numRows = shader.getType(type.definition.word(2)).definition.word(3);
    auto numAdds = shader.getObjectType(insn.word(3)).definition.word(3);

    for(auto row = 0u; row < numRows; row++)
    {
        for(auto col = 0u; col < numColumns; col++)
        {
            SIMD::Float v = lhs.Float(row) * rhs.Float(col * numAdds);
            for(auto i = 1u; i < numAdds; i++)
            {
                v = MulAdd(lhs.Float(i * numRows + row), rhs.Float(col * numAdds + i), v);
            }
            dst.move(numRows * col + row, v);
        }
    }
}

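// OpOuterProduct: produces a matrix with lhs.componentCount rows and
// rhs.componentCount columns, where element (row, col) = lhs[row] * rhs[col].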
void SpirvEmitter::EmitOuterProduct(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    auto numRows = lhs.componentCount;
    auto numCols = rhs.componentCount;

    for(auto col = 0u; col < numCols; col++)
    {
        for(auto row = 0u; row < numRows; row++)
        {
            dst.move(col * numRows + row, lhs.Float(row) * rhs.Float(col));
        }
    }
}

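// OpTranspose: element (row, col) of the result is element (col, row) of the
// source, with both matrices stored in the flattened column-major layout.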
void SpirvEmitter::EmitTranspose(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto mat = Operand(shader, *this, insn.word(3));

    auto numCols = type.definition.word(3);
    auto numRows = shader.getType(type.definition.word(2)).componentCount;

    for(auto col = 0u; col < numCols; col++)
    {
        for(auto row = 0u; row < numRows; row++)
        {
            dst.move(col * numRows + row, mat.Float(row * numCols + col));
        }
    }
}

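// OpBitcast between a pointer and its integer bit pattern. A 32-bit pointer
// maps to a single 32-bit component; a 64-bit pointer maps to two components,
// with the low word first.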
void SpirvEmitter::EmitBitcastPointer(Spirv::Object::ID resultID, Operand &src)
{
    if(src.isPointer())  // Pointer -> Integer bits
    {
        if(sizeof(void *) == 4)  // 32-bit pointers
        {
            SIMD::UInt bits;
            src.Pointer().castTo(bits);

            auto &dst = createIntermediate(resultID, 1);
            dst.move(0, bits);
        }
        else  // 64-bit pointers
        {
            ASSERT(sizeof(void *) == 8);
            // Casting a 64-bit pointer into two 32-bit integers
            auto &ptr = src.Pointer();
            SIMD::UInt lowerBits, upperBits;
            ptr.castTo(lowerBits, upperBits);

            auto &dst = createIntermediate(resultID, 2);
            dst.move(0, lowerBits);
            dst.move(1, upperBits);
        }
    }
    else  // Integer bits -> Pointer
    {
        if(sizeof(void *) == 4)  // 32-bit pointers
        {
            createPointer(resultID, SIMD::Pointer(src.UInt(0)));
        }
        else  // 64-bit pointers
        {
            ASSERT(sizeof(void *) == 8);
            // Casting two 32-bit integers into a 64-bit pointer
            createPointer(resultID, SIMD::Pointer(src.UInt(0), src.UInt(1)));
        }
    }
}

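// Component-wise unary instructions: bitwise/logical negation, bit-field
// manipulation, numeric conversions, derivatives, and OpQuantizeToF16.
// Bitcasts which involve a pointer on either side are routed to
// EmitBitcastPointer above instead.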
void SpirvEmitter::EmitUnaryOp(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto src = Operand(shader, *this, insn.word(3));

    bool dstIsPointer = shader.getObject(insn.resultId()).kind == Spirv::Object::Kind::Pointer;
    bool srcIsPointer = src.isPointer();
    if(srcIsPointer || dstIsPointer)
    {
        ASSERT(insn.opcode() == spv::OpBitcast);
        ASSERT(srcIsPointer || (type.componentCount == 1));  // When the output is a pointer, it's a single pointer

        return EmitBitcastPointer(insn.resultId(), src);
    }

    auto &dst = createIntermediate(insn.resultId(), type.componentCount);

    for(auto i = 0u; i < type.componentCount; i++)
    {
        switch(insn.opcode())
        {
        case spv::OpNot:
        case spv::OpLogicalNot:  // logical not == bitwise not due to all-bits boolean representation
            dst.move(i, ~src.UInt(i));
            break;
        case spv::OpBitFieldInsert:
            {
                auto insert = Operand(shader, *this, insn.word(4)).UInt(i);
                auto offset = Operand(shader, *this, insn.word(5)).UInt(0);
                auto count = Operand(shader, *this, insn.word(6)).UInt(0);
                auto one = SIMD::UInt(1);
                auto v = src.UInt(i);
                auto mask = Bitmask32(offset + count) ^ Bitmask32(offset);  // Selects bits [offset, offset + count)
                dst.move(i, (v & ~mask) | ((insert << offset) & mask));
            }
            break;
        case spv::OpBitFieldSExtract:
        case spv::OpBitFieldUExtract:
            {
                auto offset = Operand(shader, *this, insn.word(4)).UInt(0);
                auto count = Operand(shader, *this, insn.word(5)).UInt(0);
                auto one = SIMD::UInt(1);
                auto v = src.UInt(i);
                SIMD::UInt out = (v >> offset) & Bitmask32(count);
                if(insn.opcode() == spv::OpBitFieldSExtract)
                {
                    auto sign = out & NthBit32(count - one);
                    auto sext = ~(sign - one);
                    out |= sext;
                }
                dst.move(i, out);
            }
            break;
        case spv::OpBitReverse:
            {
                // TODO: Add an intrinsic to reactor. Even if there isn't a
                // single vector instruction, there may be target-dependent
                // ways to make this faster.
                // https://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel
                SIMD::UInt v = src.UInt(i);
                v = ((v >> 1) & SIMD::UInt(0x55555555)) | ((v & SIMD::UInt(0x55555555)) << 1);
                v = ((v >> 2) & SIMD::UInt(0x33333333)) | ((v & SIMD::UInt(0x33333333)) << 2);
                v = ((v >> 4) & SIMD::UInt(0x0F0F0F0F)) | ((v & SIMD::UInt(0x0F0F0F0F)) << 4);
                v = ((v >> 8) & SIMD::UInt(0x00FF00FF)) | ((v & SIMD::UInt(0x00FF00FF)) << 8);
                v = (v >> 16) | (v << 16);
                dst.move(i, v);
            }
            break;
        case spv::OpBitCount:
            dst.move(i, CountBits(src.UInt(i)));
            break;
        case spv::OpSNegate:
            dst.move(i, -src.Int(i));
            break;
        case spv::OpFNegate:
            dst.move(i, -src.Float(i));
            break;
        case spv::OpConvertFToU:
            dst.move(i, SIMD::UInt(src.Float(i)));
            break;
        case spv::OpConvertFToS:
            dst.move(i, SIMD::Int(src.Float(i)));
            break;
        case spv::OpConvertSToF:
            dst.move(i, SIMD::Float(src.Int(i)));
            break;
        case spv::OpConvertUToF:
            dst.move(i, SIMD::Float(src.UInt(i)));
            break;
        case spv::OpBitcast:
            dst.move(i, src.Float(i));
            break;
        case spv::OpIsInf:
            dst.move(i, IsInf(src.Float(i)));
            break;
        case spv::OpIsNan:
            dst.move(i, IsNan(src.Float(i)));
            break;
        case spv::OpDPdx:
        case spv::OpDPdxCoarse:
            // Derivative instructions: FS invocations are laid out like so:
            //   0 1
            //   2 3
            ASSERT(SIMD::Width == 4);  // All cross-lane instructions will need care when using a different width
            dst.move(i, SIMD::Float(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)));
            break;
        case spv::OpDPdy:
        case spv::OpDPdyCoarse:
            dst.move(i, SIMD::Float(Extract(src.Float(i), 2) - Extract(src.Float(i), 0)));
            break;
        case spv::OpFwidth:
        case spv::OpFwidthCoarse:
            dst.move(i, SIMD::Float(Abs(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)) + Abs(Extract(src.Float(i), 2) - Extract(src.Float(i), 0))));
            break;
        case spv::OpDPdxFine:
            {
                auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
                auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
                SIMD::Float v = SIMD::Float(firstRow);
                v = Insert(v, secondRow, 2);
                v = Insert(v, secondRow, 3);
                dst.move(i, v);
            }
            break;
        case spv::OpDPdyFine:
            {
                auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
                auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
                SIMD::Float v = SIMD::Float(firstColumn);
                v = Insert(v, secondColumn, 1);
                v = Insert(v, secondColumn, 3);
                dst.move(i, v);
            }
            break;
        case spv::OpFwidthFine:
            {
                auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
                auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
                SIMD::Float dpdx = SIMD::Float(firstRow);
                dpdx = Insert(dpdx, secondRow, 2);
                dpdx = Insert(dpdx, secondRow, 3);
                auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
                auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
                SIMD::Float dpdy = SIMD::Float(firstColumn);
                dpdy = Insert(dpdy, secondColumn, 1);
                dpdy = Insert(dpdy, secondColumn, 3);
                dst.move(i, Abs(dpdx) + Abs(dpdy));
            }
            break;
        case spv::OpQuantizeToF16:
            {
                // Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
                auto abs = Abs(src.Float(i));
                auto sign = src.Int(i) & SIMD::Int(0x80000000);
                auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));  // Smaller than the smallest normal half (2^-14)
                auto isInf = CmpGT(abs, SIMD::Float(65504.0f));       // Larger than the largest finite half
                auto isNaN = IsNan(abs);
                auto isInfOrNan = isInf | isNaN;
                SIMD::Int v = src.Int(i) & SIMD::Int(0xFFFFE000);  // Truncate the mantissa to the 10 bits a half can hold
                v &= ~isZero | SIMD::Int(0x80000000);              // Flush values below the half range to (signed) zero
                v = sign | (isInfOrNan & SIMD::Int(0x7F800000)) | (~isInfOrNan & v);  // Out-of-range values become infinity
                v |= isNaN & SIMD::Int(0x400000);                  // Keep NaN quiet
                dst.move(i, v);
            }
            break;
        default:
            UNREACHABLE("%s", shader.OpcodeName(insn.opcode()));
        }
    }
}

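// Component-wise binary arithmetic, comparison, shift, and logical
// instructions. The loop runs over the operand component count rather than
// the result's, because the extended ops (OpSMulExtended, OpUMulExtended,
// OpIAddCarry, OpISubBorrow) return a two-member struct, written here as
// 2 * N consecutive components.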
void SpirvEmitter::EmitBinaryOp(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto &lhsType = shader.getObjectType(insn.word(3));
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    for(auto i = 0u; i < lhsType.componentCount; i++)
    {
        switch(insn.opcode())
        {
        case spv::OpIAdd:
            dst.move(i, lhs.Int(i) + rhs.Int(i));
            break;
        case spv::OpISub:
            dst.move(i, lhs.Int(i) - rhs.Int(i));
            break;
        case spv::OpIMul:
            dst.move(i, lhs.Int(i) * rhs.Int(i));
            break;
        case spv::OpSDiv:
            {
                SIMD::Int a = lhs.Int(i);
                SIMD::Int b = rhs.Int(i);
                b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
                a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
                dst.move(i, a / b);
            }
            break;
        case spv::OpUDiv:
            {
                auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
                dst.move(i, lhs.UInt(i) / (rhs.UInt(i) | zeroMask));
            }
            break;
        case spv::OpSRem:
            {
                SIMD::Int a = lhs.Int(i);
                SIMD::Int b = rhs.Int(i);
                b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
                a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
                dst.move(i, a % b);
            }
            break;
        case spv::OpSMod:
            {
                SIMD::Int a = lhs.Int(i);
                SIMD::Int b = rhs.Int(i);
                b = b | CmpEQ(b, SIMD::Int(0));  // prevent divide-by-zero
                a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
                auto mod = a % b;
                // If a and b have opposite signs, the remainder operation takes
                // the sign from a but OpSMod is supposed to take the sign of b.
                // Adding b will ensure that the result has the correct sign and
                // that it is still congruent to a modulo b.
                // For example, -1 % 3 yields -1, while OpSMod expects 2;
                // adding b (3) gives 2, which is still congruent to -1 modulo 3.
                //
                // See also http://mathforum.org/library/drmath/view/52343.html
                auto signDiff = CmpNEQ(CmpGE(a, SIMD::Int(0)), CmpGE(b, SIMD::Int(0)));
                auto fixedMod = mod + (b & CmpNEQ(mod, SIMD::Int(0)) & signDiff);
                dst.move(i, As<SIMD::Float>(fixedMod));
            }
            break;
        case spv::OpUMod:
            {
                auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
                dst.move(i, lhs.UInt(i) % (rhs.UInt(i) | zeroMask));
            }
            break;
        case spv::OpIEqual:
        case spv::OpLogicalEqual:
            dst.move(i, CmpEQ(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpINotEqual:
        case spv::OpLogicalNotEqual:
            dst.move(i, CmpNEQ(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpUGreaterThan:
            dst.move(i, CmpGT(lhs.UInt(i), rhs.UInt(i)));
            break;
        case spv::OpSGreaterThan:
            dst.move(i, CmpGT(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpUGreaterThanEqual:
            dst.move(i, CmpGE(lhs.UInt(i), rhs.UInt(i)));
            break;
        case spv::OpSGreaterThanEqual:
            dst.move(i, CmpGE(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpULessThan:
            dst.move(i, CmpLT(lhs.UInt(i), rhs.UInt(i)));
            break;
        case spv::OpSLessThan:
            dst.move(i, CmpLT(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpULessThanEqual:
            dst.move(i, CmpLE(lhs.UInt(i), rhs.UInt(i)));
            break;
        case spv::OpSLessThanEqual:
            dst.move(i, CmpLE(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpFAdd:
            dst.move(i, lhs.Float(i) + rhs.Float(i));
            break;
        case spv::OpFSub:
            dst.move(i, lhs.Float(i) - rhs.Float(i));
            break;
        case spv::OpFMul:
            dst.move(i, lhs.Float(i) * rhs.Float(i));
            break;
        case spv::OpFDiv:
            // TODO(b/169760262): Optimize using reciprocal instructions (2.5 ULP).
            // TODO(b/222218659): Optimize for RelaxedPrecision (2.5 ULP).
            dst.move(i, lhs.Float(i) / rhs.Float(i));
            break;
        case spv::OpFMod:
            // TODO(b/126873455): Inaccurate for values greater than 2^24.
            // TODO(b/169760262): Optimize using reciprocal instructions.
            // TODO(b/222218659): Optimize for RelaxedPrecision.
            dst.move(i, lhs.Float(i) - rhs.Float(i) * Floor(lhs.Float(i) / rhs.Float(i)));
            break;
        case spv::OpFRem:
            // TODO(b/169760262): Optimize using reciprocal instructions.
            // TODO(b/222218659): Optimize for RelaxedPrecision.
            dst.move(i, lhs.Float(i) % rhs.Float(i));
            break;
        case spv::OpFOrdEqual:
            dst.move(i, CmpEQ(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordEqual:
            dst.move(i, CmpUEQ(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFOrdNotEqual:
            dst.move(i, CmpNEQ(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordNotEqual:
            dst.move(i, CmpUNEQ(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFOrdLessThan:
            dst.move(i, CmpLT(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordLessThan:
            dst.move(i, CmpULT(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFOrdGreaterThan:
            dst.move(i, CmpGT(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordGreaterThan:
            dst.move(i, CmpUGT(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFOrdLessThanEqual:
            dst.move(i, CmpLE(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordLessThanEqual:
            dst.move(i, CmpULE(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFOrdGreaterThanEqual:
            dst.move(i, CmpGE(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpFUnordGreaterThanEqual:
            dst.move(i, CmpUGE(lhs.Float(i), rhs.Float(i)));
            break;
        case spv::OpShiftRightLogical:
            dst.move(i, lhs.UInt(i) >> rhs.UInt(i));
            break;
        case spv::OpShiftRightArithmetic:
            dst.move(i, lhs.Int(i) >> rhs.Int(i));
            break;
        case spv::OpShiftLeftLogical:
            dst.move(i, lhs.UInt(i) << rhs.UInt(i));
            break;
        case spv::OpBitwiseOr:
        case spv::OpLogicalOr:
            dst.move(i, lhs.UInt(i) | rhs.UInt(i));
            break;
        case spv::OpBitwiseXor:
            dst.move(i, lhs.UInt(i) ^ rhs.UInt(i));
            break;
        case spv::OpBitwiseAnd:
        case spv::OpLogicalAnd:
            dst.move(i, lhs.UInt(i) & rhs.UInt(i));
            break;
        case spv::OpSMulExtended:
            // Extended ops: result is a structure containing two members of the same type as lhs & rhs.
            // In our flat view then, component i is the i'th component of the first member;
            // component i + N is the i'th component of the second member.
            dst.move(i, lhs.Int(i) * rhs.Int(i));
            dst.move(i + lhsType.componentCount, MulHigh(lhs.Int(i), rhs.Int(i)));
            break;
        case spv::OpUMulExtended:
            dst.move(i, lhs.UInt(i) * rhs.UInt(i));
            dst.move(i + lhsType.componentCount, MulHigh(lhs.UInt(i), rhs.UInt(i)));
            break;
        case spv::OpIAddCarry:
            dst.move(i, lhs.UInt(i) + rhs.UInt(i));
            dst.move(i + lhsType.componentCount, CmpLT(dst.UInt(i), lhs.UInt(i)) >> 31);  // Carry out when the truncated sum wrapped below lhs
            break;
        case spv::OpISubBorrow:
            dst.move(i, lhs.UInt(i) - rhs.UInt(i));
            dst.move(i + lhsType.componentCount, CmpLT(lhs.UInt(i), rhs.UInt(i)) >> 31);  // Borrow when lhs < rhs
            break;
        default:
            UNREACHABLE("%s", shader.OpcodeName(insn.opcode()));
        }
    }

    SPIRV_SHADER_DBG("{0}: {1}", insn.word(2), dst);
    SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
    SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);
}

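// OpDot and the integer dot product instructions (OpSDot, OpUDot, OpSUDot and
// their *AccSat variants) always produce a scalar result.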
void SpirvEmitter::EmitDot(Spirv::InsnIterator insn)
{
    auto &type = shader.getType(insn.resultTypeId());
    ASSERT(type.componentCount == 1);
    auto &dst = createIntermediate(insn.resultId(), type.componentCount);
    auto &lhsType = shader.getObjectType(insn.word(3));
    auto lhs = Operand(shader, *this, insn.word(3));
    auto rhs = Operand(shader, *this, insn.word(4));

    auto opcode = insn.opcode();
    switch(opcode)
    {
    case spv::OpDot:
        dst.move(0, FDot(lhsType.componentCount, lhs, rhs));
        break;
    case spv::OpSDot:
        dst.move(0, SDot(lhsType.componentCount, lhs, rhs, nullptr));
        break;
    case spv::OpUDot:
        dst.move(0, UDot(lhsType.componentCount, lhs, rhs, nullptr));
        break;
    case spv::OpSUDot:
        dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, nullptr));
        break;
    case spv::OpSDotAccSat:
        {
            auto accum = Operand(shader, *this, insn.word(5));
            dst.move(0, SDot(lhsType.componentCount, lhs, rhs, &accum));
        }
        break;
    case spv::OpUDotAccSat:
        {
            auto accum = Operand(shader, *this, insn.word(5));
            dst.move(0, UDot(lhsType.componentCount, lhs, rhs, &accum));
        }
        break;
    case spv::OpSUDotAccSat:
        {
            auto accum = Operand(shader, *this, insn.word(5));
            dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, &accum));
        }
        break;
    default:
        UNREACHABLE("%s", shader.OpcodeName(opcode));
        break;
    }

    SPIRV_SHADER_DBG("{0}: {1}", insn.resultId(), dst);
    SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
    SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);
}

SIMD::Float SpirvEmitter::FDot(unsigned numComponents, const Operand &x, const Operand &y)
{
    SIMD::Float d = x.Float(0) * y.Float(0);

    for(auto i = 1u; i < numComponents; i++)
    {
        d = MulAdd(x.Float(i), y.Float(i), d);
    }

    return d;
}

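// SDot, UDot and SUDot compute the integer dot products. numComponents == 1
// denotes the 4x8-bit packed format, where each 32-bit lane holds four packed
// bytes that are multiplied and summed per lane; otherwise the operands are
// regular integer vectors. When an accumulator is provided (the *AccSat
// variants), it is added with saturation.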
SIMD::Int SpirvEmitter::SDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
{
    SIMD::Int d(0);

    if(numComponents == 1)  // 4x8bit packed
    {
        numComponents = 4;
        for(auto i = 0u; i < numComponents; i++)
        {
            Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
            Int4 ys(As<SByte4>(Extract(y.Int(0), i)));

            Int4 xy = xs * ys;
            rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

            d = Insert(d, sum, i);
        }
    }
    else
    {
        d = x.Int(0) * y.Int(0);

        for(auto i = 1u; i < numComponents; i++)
        {
            d += x.Int(i) * y.Int(i);
        }
    }

    if(accum)
    {
        d = AddSat(d, accum->Int(0));
    }

    return d;
}

SIMD::UInt SpirvEmitter::UDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
{
    SIMD::UInt d(0);

    if(numComponents == 1)  // 4x8bit packed
    {
        numComponents = 4;
        for(auto i = 0u; i < numComponents; i++)
        {
            Int4 xs(As<Byte4>(Extract(x.Int(0), i)));
            Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

            UInt4 xy = xs * ys;
            rr::UInt sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

            d = Insert(d, sum, i);
        }
    }
    else
    {
        d = x.UInt(0) * y.UInt(0);

        for(auto i = 1u; i < numComponents; i++)
        {
            d += x.UInt(i) * y.UInt(i);
        }
    }

    if(accum)
    {
        d = AddSat(d, accum->UInt(0));
    }

    return d;
}

SIMD::Int SpirvEmitter::SUDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
{
    SIMD::Int d(0);

    if(numComponents == 1)  // 4x8bit packed
    {
        numComponents = 4;
        for(auto i = 0u; i < numComponents; i++)
        {
            Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
            Int4 ys(As<Byte4>(Extract(y.Int(0), i)));

            Int4 xy = xs * ys;
            rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);

            d = Insert(d, sum, i);
        }
    }
    else
    {
        d = x.Int(0) * As<SIMD::Int>(y.UInt(0));

        for(auto i = 1u; i < numComponents; i++)
        {
            d += x.Int(i) * As<SIMD::Int>(y.UInt(i));
        }
    }

    if(accum)
    {
        d = AddSat(d, accum->Int(0));
    }

    return d;
}

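// Saturating 32-bit signed addition, used by the *AccSat dot products: the
// result is clamped to INT32_MAX on positive overflow and to INT32_MIN on
// negative overflow.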
SIMD::Int SpirvEmitter::AddSat(RValue<SIMD::Int> a, RValue<SIMD::Int> b)
{
    SIMD::Int sum = a + b;
    SIMD::Int sSign = sum >> 31;
    SIMD::Int aSign = a >> 31;
    SIMD::Int bSign = b >> 31;

    // Overflow happened if both numbers added have the same sign and the sum has a different sign
    SIMD::Int oob = ~(aSign ^ bSign) & (aSign ^ sSign);
    SIMD::Int overflow = oob & sSign;
    SIMD::Int underflow = oob & aSign;

    return (overflow & std::numeric_limits<int32_t>::max()) |
           (underflow & std::numeric_limits<int32_t>::min()) |
           (~oob & sum);
}

SIMD::UInt SpirvEmitter::AddSat(RValue<SIMD::UInt> a, RValue<SIMD::UInt> b)
{
    SIMD::UInt sum = a + b;

    // Overflow happened if the sum of unsigned integers is smaller than either of the 2 numbers being added
    // Note: CmpLT()'s return value is automatically set to UINT_MAX when true
    return CmpLT(sum, a) | sum;
}

}  // namespace sw