xref: /aosp_15_r20/external/swiftshader/src/Pipeline/SpirvShaderArithmetic.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "SpirvShader.hpp"
16 #include "SpirvShaderDebug.hpp"
17 
18 #include "ShaderCore.hpp"
19 
20 #include <spirv/unified1/spirv.hpp>
21 
22 #include <limits>
23 
24 namespace sw {
25 
EmitVectorTimesScalar(Spirv::InsnIterator insn)26 void SpirvEmitter::EmitVectorTimesScalar(Spirv::InsnIterator insn)
27 {
28 	auto &type = shader.getType(insn.resultTypeId());
29 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
30 	auto lhs = Operand(shader, *this, insn.word(3));
31 	auto rhs = Operand(shader, *this, insn.word(4));
32 
33 	for(auto i = 0u; i < type.componentCount; i++)
34 	{
35 		dst.move(i, lhs.Float(i) * rhs.Float(0));
36 	}
37 }
38 
EmitMatrixTimesVector(Spirv::InsnIterator insn)39 void SpirvEmitter::EmitMatrixTimesVector(Spirv::InsnIterator insn)
40 {
41 	auto &type = shader.getType(insn.resultTypeId());
42 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
43 	auto lhs = Operand(shader, *this, insn.word(3));
44 	auto rhs = Operand(shader, *this, insn.word(4));
45 
46 	for(auto i = 0u; i < type.componentCount; i++)
47 	{
48 		SIMD::Float v = lhs.Float(i) * rhs.Float(0);
49 		for(auto j = 1u; j < rhs.componentCount; j++)
50 		{
51 			v = MulAdd(lhs.Float(i + type.componentCount * j), rhs.Float(j), v);
52 		}
53 		dst.move(i, v);
54 	}
55 }
56 
EmitVectorTimesMatrix(Spirv::InsnIterator insn)57 void SpirvEmitter::EmitVectorTimesMatrix(Spirv::InsnIterator insn)
58 {
59 	auto &type = shader.getType(insn.resultTypeId());
60 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
61 	auto lhs = Operand(shader, *this, insn.word(3));
62 	auto rhs = Operand(shader, *this, insn.word(4));
63 
64 	for(auto i = 0u; i < type.componentCount; i++)
65 	{
66 		SIMD::Float v = lhs.Float(0) * rhs.Float(i * lhs.componentCount);
67 		for(auto j = 1u; j < lhs.componentCount; j++)
68 		{
69 			v = MulAdd(lhs.Float(j), rhs.Float(i * lhs.componentCount + j), v);
70 		}
71 		dst.move(i, v);
72 	}
73 }
74 
// Emits code for OpMatrixTimesMatrix. All matrices are stored column-major
// in flat intermediates: element (row, col) lives at flat index
// col * rowCount + row.
void SpirvEmitter::EmitMatrixTimesMatrix(Spirv::InsnIterator insn)
{
	auto &type = shader.getType(insn.resultTypeId());
	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
	auto lhs = Operand(shader, *this, insn.word(3));
	auto rhs = Operand(shader, *this, insn.word(4));

	// Result column count comes from the result matrix type (word 3 of its
	// definition); row count from its column (vector) type. The length of
	// each inner product equals the lhs matrix's column count, which also
	// equals the rhs matrix's row count (and thus the rhs column stride).
	auto numColumns = type.definition.word(3);
	auto numRows = shader.getType(type.definition.word(2)).definition.word(3);
	auto numAdds = shader.getObjectType(insn.word(3)).definition.word(3);

	for(auto row = 0u; row < numRows; row++)
	{
		for(auto col = 0u; col < numColumns; col++)
		{
			// dst(row, col) = sum over i of lhs(row, i) * rhs(i, col)
			SIMD::Float v = lhs.Float(row) * rhs.Float(col * numAdds);
			for(auto i = 1u; i < numAdds; i++)
			{
				v = MulAdd(lhs.Float(i * numRows + row), rhs.Float(col * numAdds + i), v);
			}
			dst.move(numRows * col + row, v);
		}
	}
}
99 
EmitOuterProduct(Spirv::InsnIterator insn)100 void SpirvEmitter::EmitOuterProduct(Spirv::InsnIterator insn)
101 {
102 	auto &type = shader.getType(insn.resultTypeId());
103 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
104 	auto lhs = Operand(shader, *this, insn.word(3));
105 	auto rhs = Operand(shader, *this, insn.word(4));
106 
107 	auto numRows = lhs.componentCount;
108 	auto numCols = rhs.componentCount;
109 
110 	for(auto col = 0u; col < numCols; col++)
111 	{
112 		for(auto row = 0u; row < numRows; row++)
113 		{
114 			dst.move(col * numRows + row, lhs.Float(row) * rhs.Float(col));
115 		}
116 	}
117 }
118 
EmitTranspose(Spirv::InsnIterator insn)119 void SpirvEmitter::EmitTranspose(Spirv::InsnIterator insn)
120 {
121 	auto &type = shader.getType(insn.resultTypeId());
122 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
123 	auto mat = Operand(shader, *this, insn.word(3));
124 
125 	auto numCols = type.definition.word(3);
126 	auto numRows = shader.getType(type.definition.word(2)).componentCount;
127 
128 	for(auto col = 0u; col < numCols; col++)
129 	{
130 		for(auto row = 0u; row < numRows; row++)
131 		{
132 			dst.move(col * numRows + row, mat.Float(row * numCols + col));
133 		}
134 	}
135 }
136 
EmitBitcastPointer(Spirv::Object::ID resultID,Operand & src)137 void SpirvEmitter::EmitBitcastPointer(Spirv::Object::ID resultID, Operand &src)
138 {
139 	if(src.isPointer())  // Pointer -> Integer bits
140 	{
141 		if(sizeof(void *) == 4)  // 32-bit pointers
142 		{
143 			SIMD::UInt bits;
144 			src.Pointer().castTo(bits);
145 
146 			auto &dst = createIntermediate(resultID, 1);
147 			dst.move(0, bits);
148 		}
149 		else  // 64-bit pointers
150 		{
151 			ASSERT(sizeof(void *) == 8);
152 			// Casting a 64 bit pointer into 2 32bit integers
153 			auto &ptr = src.Pointer();
154 			SIMD::UInt lowerBits, upperBits;
155 			ptr.castTo(lowerBits, upperBits);
156 
157 			auto &dst = createIntermediate(resultID, 2);
158 			dst.move(0, lowerBits);
159 			dst.move(1, upperBits);
160 		}
161 	}
162 	else  // Integer bits -> Pointer
163 	{
164 		if(sizeof(void *) == 4)  // 32-bit pointers
165 		{
166 			createPointer(resultID, SIMD::Pointer(src.UInt(0)));
167 		}
168 		else  // 64-bit pointers
169 		{
170 			ASSERT(sizeof(void *) == 8);
171 			// Casting two 32-bit integers into a 64-bit pointer
172 			createPointer(resultID, SIMD::Pointer(src.UInt(0), src.UInt(1)));
173 		}
174 	}
175 }
176 
EmitUnaryOp(Spirv::InsnIterator insn)177 void SpirvEmitter::EmitUnaryOp(Spirv::InsnIterator insn)
178 {
179 	auto &type = shader.getType(insn.resultTypeId());
180 	auto src = Operand(shader, *this, insn.word(3));
181 
182 	bool dstIsPointer = shader.getObject(insn.resultId()).kind == Spirv::Object::Kind::Pointer;
183 	bool srcIsPointer = src.isPointer();
184 	if(srcIsPointer || dstIsPointer)
185 	{
186 		ASSERT(insn.opcode() == spv::OpBitcast);
187 		ASSERT(srcIsPointer || (type.componentCount == 1));  // When the ouput is a pointer, it's a single pointer
188 
189 		return EmitBitcastPointer(insn.resultId(), src);
190 	}
191 
192 	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
193 
194 	for(auto i = 0u; i < type.componentCount; i++)
195 	{
196 		switch(insn.opcode())
197 		{
198 		case spv::OpNot:
199 		case spv::OpLogicalNot:  // logical not == bitwise not due to all-bits boolean representation
200 			dst.move(i, ~src.UInt(i));
201 			break;
202 		case spv::OpBitFieldInsert:
203 			{
204 				auto insert = Operand(shader, *this, insn.word(4)).UInt(i);
205 				auto offset = Operand(shader, *this, insn.word(5)).UInt(0);
206 				auto count = Operand(shader, *this, insn.word(6)).UInt(0);
207 				auto one = SIMD::UInt(1);
208 				auto v = src.UInt(i);
209 				auto mask = Bitmask32(offset + count) ^ Bitmask32(offset);
210 				dst.move(i, (v & ~mask) | ((insert << offset) & mask));
211 			}
212 			break;
213 		case spv::OpBitFieldSExtract:
214 		case spv::OpBitFieldUExtract:
215 			{
216 				auto offset = Operand(shader, *this, insn.word(4)).UInt(0);
217 				auto count = Operand(shader, *this, insn.word(5)).UInt(0);
218 				auto one = SIMD::UInt(1);
219 				auto v = src.UInt(i);
220 				SIMD::UInt out = (v >> offset) & Bitmask32(count);
221 				if(insn.opcode() == spv::OpBitFieldSExtract)
222 				{
223 					auto sign = out & NthBit32(count - one);
224 					auto sext = ~(sign - one);
225 					out |= sext;
226 				}
227 				dst.move(i, out);
228 			}
229 			break;
230 		case spv::OpBitReverse:
231 			{
232 				// TODO: Add an intrinsic to reactor. Even if there isn't a
233 				// single vector instruction, there may be target-dependent
234 				// ways to make this faster.
235 				// https://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel
236 				SIMD::UInt v = src.UInt(i);
237 				v = ((v >> 1) & SIMD::UInt(0x55555555)) | ((v & SIMD::UInt(0x55555555)) << 1);
238 				v = ((v >> 2) & SIMD::UInt(0x33333333)) | ((v & SIMD::UInt(0x33333333)) << 2);
239 				v = ((v >> 4) & SIMD::UInt(0x0F0F0F0F)) | ((v & SIMD::UInt(0x0F0F0F0F)) << 4);
240 				v = ((v >> 8) & SIMD::UInt(0x00FF00FF)) | ((v & SIMD::UInt(0x00FF00FF)) << 8);
241 				v = (v >> 16) | (v << 16);
242 				dst.move(i, v);
243 			}
244 			break;
245 		case spv::OpBitCount:
246 			dst.move(i, CountBits(src.UInt(i)));
247 			break;
248 		case spv::OpSNegate:
249 			dst.move(i, -src.Int(i));
250 			break;
251 		case spv::OpFNegate:
252 			dst.move(i, -src.Float(i));
253 			break;
254 		case spv::OpConvertFToU:
255 			dst.move(i, SIMD::UInt(src.Float(i)));
256 			break;
257 		case spv::OpConvertFToS:
258 			dst.move(i, SIMD::Int(src.Float(i)));
259 			break;
260 		case spv::OpConvertSToF:
261 			dst.move(i, SIMD::Float(src.Int(i)));
262 			break;
263 		case spv::OpConvertUToF:
264 			dst.move(i, SIMD::Float(src.UInt(i)));
265 			break;
266 		case spv::OpBitcast:
267 			dst.move(i, src.Float(i));
268 			break;
269 		case spv::OpIsInf:
270 			dst.move(i, IsInf(src.Float(i)));
271 			break;
272 		case spv::OpIsNan:
273 			dst.move(i, IsNan(src.Float(i)));
274 			break;
275 		case spv::OpDPdx:
276 		case spv::OpDPdxCoarse:
277 			// Derivative instructions: FS invocations are laid out like so:
278 			//    0 1
279 			//    2 3
280 			ASSERT(SIMD::Width == 4);  // All cross-lane instructions will need care when using a different width
281 			dst.move(i, SIMD::Float(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)));
282 			break;
283 		case spv::OpDPdy:
284 		case spv::OpDPdyCoarse:
285 			dst.move(i, SIMD::Float(Extract(src.Float(i), 2) - Extract(src.Float(i), 0)));
286 			break;
287 		case spv::OpFwidth:
288 		case spv::OpFwidthCoarse:
289 			dst.move(i, SIMD::Float(Abs(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)) + Abs(Extract(src.Float(i), 2) - Extract(src.Float(i), 0))));
290 			break;
291 		case spv::OpDPdxFine:
292 			{
293 				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
294 				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
295 				SIMD::Float v = SIMD::Float(firstRow);
296 				v = Insert(v, secondRow, 2);
297 				v = Insert(v, secondRow, 3);
298 				dst.move(i, v);
299 			}
300 			break;
301 		case spv::OpDPdyFine:
302 			{
303 				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
304 				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
305 				SIMD::Float v = SIMD::Float(firstColumn);
306 				v = Insert(v, secondColumn, 1);
307 				v = Insert(v, secondColumn, 3);
308 				dst.move(i, v);
309 			}
310 			break;
311 		case spv::OpFwidthFine:
312 			{
313 				auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
314 				auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
315 				SIMD::Float dpdx = SIMD::Float(firstRow);
316 				dpdx = Insert(dpdx, secondRow, 2);
317 				dpdx = Insert(dpdx, secondRow, 3);
318 				auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
319 				auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
320 				SIMD::Float dpdy = SIMD::Float(firstColumn);
321 				dpdy = Insert(dpdy, secondColumn, 1);
322 				dpdy = Insert(dpdy, secondColumn, 3);
323 				dst.move(i, Abs(dpdx) + Abs(dpdy));
324 			}
325 			break;
326 		case spv::OpQuantizeToF16:
327 			{
328 				// Note: keep in sync with the specialization constant version in EvalSpecConstantUnaryOp
329 				auto abs = Abs(src.Float(i));
330 				auto sign = src.Int(i) & SIMD::Int(0x80000000);
331 				auto isZero = CmpLT(abs, SIMD::Float(0.000061035f));
332 				auto isInf = CmpGT(abs, SIMD::Float(65504.0f));
333 				auto isNaN = IsNan(abs);
334 				auto isInfOrNan = isInf | isNaN;
335 				SIMD::Int v = src.Int(i) & SIMD::Int(0xFFFFE000);
336 				v &= ~isZero | SIMD::Int(0x80000000);
337 				v = sign | (isInfOrNan & SIMD::Int(0x7F800000)) | (~isInfOrNan & v);
338 				v |= isNaN & SIMD::Int(0x400000);
339 				dst.move(i, v);
340 			}
341 			break;
342 		default:
343 			UNREACHABLE("%s", shader.OpcodeName(insn.opcode()));
344 		}
345 	}
346 }
347 
// Emits code for two-operand arithmetic, logical, comparison, shift and
// extended-arithmetic instructions, operating component-wise. The loop runs
// over the *operand* type's component count: for the extended ops
// (Op*MulExtended, OpIAddCarry, OpISubBorrow) the result struct has twice as
// many components as the operands, and the second struct member is written
// at offset lhsType.componentCount.
void SpirvEmitter::EmitBinaryOp(Spirv::InsnIterator insn)
{
	auto &type = shader.getType(insn.resultTypeId());
	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = shader.getObjectType(insn.word(3));
	auto lhs = Operand(shader, *this, insn.word(3));
	auto rhs = Operand(shader, *this, insn.word(4));

	for(auto i = 0u; i < lhsType.componentCount; i++)
	{
		switch(insn.opcode())
		{
		case spv::OpIAdd:
			dst.move(i, lhs.Int(i) + rhs.Int(i));
			break;
		case spv::OpISub:
			dst.move(i, lhs.Int(i) - rhs.Int(i));
			break;
		case spv::OpIMul:
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			break;
		case spv::OpSDiv:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				dst.move(i, a / b);
			}
			break;
		case spv::OpUDiv:
			{
				// zeroMask is all-ones where rhs is 0; OR-ing it in makes the
				// divisor non-zero (result is undefined per SPIR-V anyway).
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) / (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpSRem:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				dst.move(i, a % b);
			}
			break;
		case spv::OpSMod:
			{
				SIMD::Int a = lhs.Int(i);
				SIMD::Int b = rhs.Int(i);
				b = b | CmpEQ(b, SIMD::Int(0));                                       // prevent divide-by-zero
				a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1)));  // prevent integer overflow
				auto mod = a % b;
				// If a and b have opposite signs, the remainder operation takes
				// the sign from a but OpSMod is supposed to take the sign of b.
				// Adding b will ensure that the result has the correct sign and
				// that it is still congruent to a modulo b.
				//
				// See also http://mathforum.org/library/drmath/view/52343.html
				auto signDiff = CmpNEQ(CmpGE(a, SIMD::Int(0)), CmpGE(b, SIMD::Int(0)));
				auto fixedMod = mod + (b & CmpNEQ(mod, SIMD::Int(0)) & signDiff);
				dst.move(i, As<SIMD::Float>(fixedMod));
			}
			break;
		case spv::OpUMod:
			{
				auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
				dst.move(i, lhs.UInt(i) % (rhs.UInt(i) | zeroMask));
			}
			break;
		case spv::OpIEqual:
		case spv::OpLogicalEqual:
			// Booleans use an all-bits representation, so integer compares
			// double as logical compares.
			dst.move(i, CmpEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpINotEqual:
		case spv::OpLogicalNotEqual:
			dst.move(i, CmpNEQ(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThan:
			dst.move(i, CmpGT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThan:
			dst.move(i, CmpGT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUGreaterThanEqual:
			dst.move(i, CmpGE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThan:
			dst.move(i, CmpLT(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThan:
			dst.move(i, CmpLT(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpULessThanEqual:
			dst.move(i, CmpLE(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpSLessThanEqual:
			dst.move(i, CmpLE(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpFAdd:
			dst.move(i, lhs.Float(i) + rhs.Float(i));
			break;
		case spv::OpFSub:
			dst.move(i, lhs.Float(i) - rhs.Float(i));
			break;
		case spv::OpFMul:
			dst.move(i, lhs.Float(i) * rhs.Float(i));
			break;
		case spv::OpFDiv:
			// TODO(b/169760262): Optimize using reciprocal instructions (2.5 ULP).
			// TODO(b/222218659): Optimize for RelaxedPrecision (2.5 ULP).
			dst.move(i, lhs.Float(i) / rhs.Float(i));
			break;
		case spv::OpFMod:
			// TODO(b/126873455): Inaccurate for values greater than 2^24.
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) - rhs.Float(i) * Floor(lhs.Float(i) / rhs.Float(i)));
			break;
		case spv::OpFRem:
			// TODO(b/169760262): Optimize using reciprocal instructions.
			// TODO(b/222218659): Optimize for RelaxedPrecision.
			dst.move(i, lhs.Float(i) % rhs.Float(i));
			break;
		case spv::OpFOrdEqual:
			dst.move(i, CmpEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordEqual:
			dst.move(i, CmpUEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdNotEqual:
			dst.move(i, CmpNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordNotEqual:
			dst.move(i, CmpUNEQ(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThan:
			dst.move(i, CmpLT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThan:
			dst.move(i, CmpULT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThan:
			dst.move(i, CmpGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThan:
			dst.move(i, CmpUGT(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdLessThanEqual:
			dst.move(i, CmpLE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordLessThanEqual:
			dst.move(i, CmpULE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFOrdGreaterThanEqual:
			dst.move(i, CmpGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpFUnordGreaterThanEqual:
			dst.move(i, CmpUGE(lhs.Float(i), rhs.Float(i)));
			break;
		case spv::OpShiftRightLogical:
			dst.move(i, lhs.UInt(i) >> rhs.UInt(i));
			break;
		case spv::OpShiftRightArithmetic:
			dst.move(i, lhs.Int(i) >> rhs.Int(i));
			break;
		case spv::OpShiftLeftLogical:
			dst.move(i, lhs.UInt(i) << rhs.UInt(i));
			break;
		case spv::OpBitwiseOr:
		case spv::OpLogicalOr:
			dst.move(i, lhs.UInt(i) | rhs.UInt(i));
			break;
		case spv::OpBitwiseXor:
			dst.move(i, lhs.UInt(i) ^ rhs.UInt(i));
			break;
		case spv::OpBitwiseAnd:
		case spv::OpLogicalAnd:
			dst.move(i, lhs.UInt(i) & rhs.UInt(i));
			break;
		case spv::OpSMulExtended:
			// Extended ops: result is a structure containing two members of the same type as lhs & rhs.
			// In our flat view then, component i is the i'th component of the first member;
			// component i + N is the i'th component of the second member.
			dst.move(i, lhs.Int(i) * rhs.Int(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.Int(i), rhs.Int(i)));
			break;
		case spv::OpUMulExtended:
			dst.move(i, lhs.UInt(i) * rhs.UInt(i));
			dst.move(i + lhsType.componentCount, MulHigh(lhs.UInt(i), rhs.UInt(i)));
			break;
		case spv::OpIAddCarry:
			// Carry occurred iff the (wrapped) sum is less than an operand;
			// shift the all-bits compare result down to a 0/1 carry value.
			dst.move(i, lhs.UInt(i) + rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(dst.UInt(i), lhs.UInt(i)) >> 31);
			break;
		case spv::OpISubBorrow:
			dst.move(i, lhs.UInt(i) - rhs.UInt(i));
			dst.move(i + lhsType.componentCount, CmpLT(lhs.UInt(i), rhs.UInt(i)) >> 31);
			break;
		default:
			UNREACHABLE("%s", shader.OpcodeName(insn.opcode()));
		}
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.word(2), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);
}
558 
// Emits code for the dot-product instructions: OpDot on floats, and the
// integer dot products OpSDot/OpUDot/OpSUDot, whose *AccSat variants add a
// third operand (word 5) that is accumulated with saturation.
void SpirvEmitter::EmitDot(Spirv::InsnIterator insn)
{
	auto &type = shader.getType(insn.resultTypeId());
	ASSERT(type.componentCount == 1);  // A dot product always yields a scalar.
	auto &dst = createIntermediate(insn.resultId(), type.componentCount);
	auto &lhsType = shader.getObjectType(insn.word(3));
	auto lhs = Operand(shader, *this, insn.word(3));
	auto rhs = Operand(shader, *this, insn.word(4));

	auto opcode = insn.opcode();
	switch(opcode)
	{
	case spv::OpDot:
		dst.move(0, FDot(lhsType.componentCount, lhs, rhs));
		break;
	case spv::OpSDot:
		// nullptr accumulator: plain dot product, no saturating add.
		dst.move(0, SDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpUDot:
		dst.move(0, UDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSUDot:
		dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, nullptr));
		break;
	case spv::OpSDotAccSat:
		{
			auto accum = Operand(shader, *this, insn.word(5));
			dst.move(0, SDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpUDotAccSat:
		{
			auto accum = Operand(shader, *this, insn.word(5));
			dst.move(0, UDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	case spv::OpSUDotAccSat:
		{
			auto accum = Operand(shader, *this, insn.word(5));
			dst.move(0, SUDot(lhsType.componentCount, lhs, rhs, &accum));
		}
		break;
	default:
		UNREACHABLE("%s", shader.OpcodeName(opcode));
		break;
	}

	SPIRV_SHADER_DBG("{0}: {1}", insn.resultId(), dst);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(3), lhs);
	SPIRV_SHADER_DBG("{0}: {1}", insn.word(4), rhs);
}
610 
FDot(unsigned numComponents,const Operand & x,const Operand & y)611 SIMD::Float SpirvEmitter::FDot(unsigned numComponents, const Operand &x, const Operand &y)
612 {
613 	SIMD::Float d = x.Float(0) * y.Float(0);
614 
615 	for(auto i = 1u; i < numComponents; i++)
616 	{
617 		d = MulAdd(x.Float(i), y.Float(i), d);
618 	}
619 
620 	return d;
621 }
622 
SDot(unsigned numComponents,const Operand & x,const Operand & y,const Operand * accum)623 SIMD::Int SpirvEmitter::SDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
624 {
625 	SIMD::Int d(0);
626 
627 	if(numComponents == 1)  // 4x8bit packed
628 	{
629 		numComponents = 4;
630 		for(auto i = 0u; i < numComponents; i++)
631 		{
632 			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
633 			Int4 ys(As<SByte4>(Extract(y.Int(0), i)));
634 
635 			Int4 xy = xs * ys;
636 			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);
637 
638 			d = Insert(d, sum, i);
639 		}
640 	}
641 	else
642 	{
643 		d = x.Int(0) * y.Int(0);
644 
645 		for(auto i = 1u; i < numComponents; i++)
646 		{
647 			d += x.Int(i) * y.Int(i);
648 		}
649 	}
650 
651 	if(accum)
652 	{
653 		d = AddSat(d, accum->Int(0));
654 	}
655 
656 	return d;
657 }
658 
UDot(unsigned numComponents,const Operand & x,const Operand & y,const Operand * accum)659 SIMD::UInt SpirvEmitter::UDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
660 {
661 	SIMD::UInt d(0);
662 
663 	if(numComponents == 1)  // 4x8bit packed
664 	{
665 		numComponents = 4;
666 		for(auto i = 0u; i < numComponents; i++)
667 		{
668 			Int4 xs(As<Byte4>(Extract(x.Int(0), i)));
669 			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));
670 
671 			UInt4 xy = xs * ys;
672 			rr::UInt sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);
673 
674 			d = Insert(d, sum, i);
675 		}
676 	}
677 	else
678 	{
679 		d = x.UInt(0) * y.UInt(0);
680 
681 		for(auto i = 1u; i < numComponents; i++)
682 		{
683 			d += x.UInt(i) * y.UInt(i);
684 		}
685 	}
686 
687 	if(accum)
688 	{
689 		d = AddSat(d, accum->UInt(0));
690 	}
691 
692 	return d;
693 }
694 
SUDot(unsigned numComponents,const Operand & x,const Operand & y,const Operand * accum)695 SIMD::Int SpirvEmitter::SUDot(unsigned numComponents, const Operand &x, const Operand &y, const Operand *accum)
696 {
697 	SIMD::Int d(0);
698 
699 	if(numComponents == 1)  // 4x8bit packed
700 	{
701 		numComponents = 4;
702 		for(auto i = 0u; i < numComponents; i++)
703 		{
704 			Int4 xs(As<SByte4>(Extract(x.Int(0), i)));
705 			Int4 ys(As<Byte4>(Extract(y.Int(0), i)));
706 
707 			Int4 xy = xs * ys;
708 			rr::Int sum = Extract(xy, 0) + Extract(xy, 1) + Extract(xy, 2) + Extract(xy, 3);
709 
710 			d = Insert(d, sum, i);
711 		}
712 	}
713 	else
714 	{
715 		d = x.Int(0) * As<SIMD::Int>(y.UInt(0));
716 
717 		for(auto i = 1u; i < numComponents; i++)
718 		{
719 			d += x.Int(i) * As<SIMD::Int>(y.UInt(i));
720 		}
721 	}
722 
723 	if(accum)
724 	{
725 		d = AddSat(d, accum->Int(0));
726 	}
727 
728 	return d;
729 }
730 
AddSat(RValue<SIMD::Int> a,RValue<SIMD::Int> b)731 SIMD::Int SpirvEmitter::AddSat(RValue<SIMD::Int> a, RValue<SIMD::Int> b)
732 {
733 	SIMD::Int sum = a + b;
734 	SIMD::Int sSign = sum >> 31;
735 	SIMD::Int aSign = a >> 31;
736 	SIMD::Int bSign = b >> 31;
737 
738 	// Overflow happened if both numbers added have the same sign and the sum has a different sign
739 	SIMD::Int oob = ~(aSign ^ bSign) & (aSign ^ sSign);
740 	SIMD::Int overflow = oob & sSign;
741 	SIMD::Int underflow = oob & aSign;
742 
743 	return (overflow & std::numeric_limits<int32_t>::max()) |
744 	       (underflow & std::numeric_limits<int32_t>::min()) |
745 	       (~oob & sum);
746 }
747 
AddSat(RValue<SIMD::UInt> a,RValue<SIMD::UInt> b)748 SIMD::UInt SpirvEmitter::AddSat(RValue<SIMD::UInt> a, RValue<SIMD::UInt> b)
749 {
750 	SIMD::UInt sum = a + b;
751 
752 	// Overflow happened if the sum of unsigned integers is smaller than either of the 2 numbers being added
753 	// Note: CmpLT()'s return value is automatically set to UINT_MAX when true
754 	return CmpLT(sum, a) | sum;
755 }
756 
757 }  // namespace sw