1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20
21 #include <cfloat>
22
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 # define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants. If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45 SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anonce4844100111::SpirvIdOrLiteral46 SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anonce4844100111::SpirvIdOrLiteral47 SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48
49 spirv::IdRef id;
50 spirv::LiteralInteger literal;
51 };
52
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such. Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 // OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 // OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 // OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
struct AccessChain
{
    // The storage class for lvalues.  If Max, it's an rvalue.
    spv::StorageClass storageClass = spv::StorageClassMax;
    // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
    // select multiple components so need special treatment when used as lvalue.
    std::vector<uint32_t> swizzles;
    // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
    // dynamicComponent will contain the id of the index.
    spirv::IdRef dynamicComponent;

    // Type of base expression, before swizzle is applied, after swizzle is applied and after
    // dynamic component is applied.
    spirv::IdRef baseTypeId;
    spirv::IdRef preSwizzleTypeId;
    spirv::IdRef postSwizzleTypeId;
    spirv::IdRef postDynamicComponentTypeId;

    // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
    // id so the instruction is not emitted twice.
    spirv::IdRef accessChainId;

    // Whether all indices are literal.  Avoids looping through indices to determine this
    // information.
    bool areAllIndicesLiteral = true;
    // The number of components in the vector, if vector and swizzle is used.  This is cached to
    // avoid a type look up when handling swizzles.
    uint8_t swizzledVectorComponentCount = 0;

    // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
    // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
    // This always corresponds to the specialization specific to the end result of the access chain,
    // not the base or any intermediary types.  For example, a struct nested in a column-major
    // interface block, with a parent block qualified as row-major would specify row-major here.
    SpirvTypeSpec typeSpec;
};
107
108 // As each node is traversed, it produces data. When visiting back the parent, this data is used to
109 // complete the data of the parent. For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression. The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
struct NodeData
{
    // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
    // expression, a reference to a variable etc.
    spirv::IdRef baseId;

    // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
    // the node, for example a list of parameters to be passed to a function, a set of ids used to
    // construct an access chain etc.
    std::vector<SpirvIdOrLiteral> idList;

    // For constructing access chains rooted at |baseId| with the indices in |idList|.
    AccessChain accessChain;
};
126
// The set of SPIR-V ids that describe one GLSL function; cached per TFunction.
struct FunctionIds
{
    // Id of the function type, return type and parameter types.
    spirv::IdRef functionTypeId;
    spirv::IdRef returnTypeId;
    spirv::IdRefList parameterTypeIds;

    // Id of the function itself.
    spirv::IdRef functionId;
};
137
// Key describing the two-field result struct some builtins produce; used to deduplicate the
// generated SPIR-V struct types.
struct BuiltInResultStruct
{
    // Some builtins require a struct result.  The struct always has two fields of a scalar or
    // vector type.
    TBasicType lsbType;
    TBasicType msbType;
    uint32_t lsbPrimarySize;
    uint32_t msbPrimarySize;
};
147
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anonce4844100111::BuiltInResultStructHash150 size_t operator()(const BuiltInResultStruct &key) const
151 {
152 static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153 ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154 ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155
156 const uint8_t properties[4] = {
157 static_cast<uint8_t>(key.lsbType),
158 static_cast<uint8_t>(key.msbType),
159 static_cast<uint8_t>(key.lsbPrimarySize),
160 static_cast<uint8_t>(key.msbPrimarySize),
161 };
162
163 return angle::ComputeGenericHash(properties, sizeof(properties));
164 }
165 };
166
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169 return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170 a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175 return accessChain.storageClass == spv::StorageClassMax;
176 }
177
178 // A traverser that generates SPIR-V as it walks the AST.
class OutputSPIRVTraverser : public TIntermTraverser
{
  public:
    OutputSPIRVTraverser(TCompiler *compiler,
                         const ShCompileOptions &compileOptions,
                         const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
                         uint32_t firstUnusedSpirvId);
    ~OutputSPIRVTraverser() override;

    // Returns the generated SPIR-V blob after traversal is complete.
    spirv::Blob getSpirv();

  protected:
    void visitSymbol(TIntermSymbol *node) override;
    void visitConstantUnion(TIntermConstantUnion *node) override;
    bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
    bool visitBinary(Visit visit, TIntermBinary *node) override;
    bool visitUnary(Visit visit, TIntermUnary *node) override;
    bool visitTernary(Visit visit, TIntermTernary *node) override;
    bool visitIfElse(Visit visit, TIntermIfElse *node) override;
    bool visitSwitch(Visit visit, TIntermSwitch *node) override;
    bool visitCase(Visit visit, TIntermCase *node) override;
    void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitBlock(Visit visit, TIntermBlock *node) override;
    bool visitGlobalQualifierDeclaration(Visit visit,
                                         TIntermGlobalQualifierDeclaration *node) override;
    bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
    bool visitLoop(Visit visit, TIntermLoop *node) override;
    bool visitBranch(Visit visit, TIntermBranch *node) override;
    void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;

  private:
    // Looks up (or lazily declares, for implicitly-defined built-ins) the SPIR-V id of a symbol,
    // and returns its storage class through |storageClass|.
    spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
                                            const TType &type,
                                            spv::StorageClass *storageClass);

    // Access chain handling.

    // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
    // determine the typeId passed to |accessChainPush*|).
    void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
    void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
    void accessChainPushLiteral(NodeData *data,
                                spirv::LiteralInteger index,
                                spirv::IdRef typeId) const;
    void accessChainPushSwizzle(NodeData *data,
                                const TVector<int> &swizzle,
                                spirv::IdRef typeId,
                                uint8_t componentCount) const;
    void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
    spirv::IdRef accessChainCollapse(NodeData *data);
    spirv::IdRef accessChainLoad(NodeData *data,
                                 const TType &valueType,
                                 spirv::IdRef *resultTypeIdOut);
    void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);

    // Access chain helpers.
    void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
    void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
    spirv::IdRef getAccessChainTypeId(NodeData *data);

    // Node data handling.
    void nodeDataInitLValue(NodeData *data,
                            spirv::IdRef baseId,
                            spirv::IdRef typeId,
                            spv::StorageClass storageClass,
                            const SpirvTypeSpec &typeSpec) const;
    void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;

    void declareConst(TIntermDeclaration *decl);
    void declareSpecConst(TIntermDeclaration *decl);
    spirv::IdRef createConstant(const TType &type,
                                TBasicType expectedBasicType,
                                const TConstantUnion *constUnion,
                                bool isConstantNullValue);
    spirv::IdRef createComplexConstant(const TType &type,
                                       spirv::IdRef typeId,
                                       const spirv::IdRefList &parameters);
    spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
    spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
                                                spirv::IdRef typeId,
                                                const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
                                                      spirv::IdRef typeId,
                                                      const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
                                                   const TType &expectedType,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
                                                     spirv::IdRef typeId,
                                                     const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
                                                    spirv::IdRef typeId,
                                                    const spirv::IdRefList &parameters);
    spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
                                                   spirv::IdRef typeId,
                                                   const spirv::IdRefList &parameters);
    // Load N values where N is the number of node's children.  In some cases, the last M values are
    // lvalues which should be skipped.
    spirv::IdRefList loadAllParams(TIntermOperator *node,
                                   size_t skipCount,
                                   spirv::IdRefList *paramTypeIds);
    void extractComponents(TIntermAggregate *node,
                           size_t componentCount,
                           const spirv::IdRefList &parameters,
                           spirv::IdRefList *extractedComponentsOut);

    void startShortCircuit(TIntermBinary *node);
    spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);

    spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
    spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);

    spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);

    void visitArrayLength(TIntermUnary *node);

    // Cast between types.  There are two kinds of casts:
    //
    // - A constructor can cast between basic types, for example vec4(someInt).
    // - Assignments, constructors, function calls etc may copy an array or struct between different
    //   block storages, invariance etc (which due to their decorations generate different SPIR-V
    //   types).  For example:
    //
    //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
    //
    spirv::IdRef castBasicType(spirv::IdRef value,
                               const TType &valueType,
                               const TType &expectedType,
                               spirv::IdRef *resultTypeIdOut);
    spirv::IdRef cast(spirv::IdRef value,
                      const TType &valueType,
                      const SpirvTypeSpec &valueTypeSpec,
                      const SpirvTypeSpec &expectedTypeSpec,
                      spirv::IdRef *resultTypeIdOut);

    // Given a list of parameters to an operator, extend the scalars to match the vectors.  GLSL
    // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
    // operations per component (requiring the scalars to turn into a vector).
    void extendScalarParamsToVector(TIntermOperator *node,
                                    spirv::IdRef resultTypeId,
                                    spirv::IdRefList *parameters);

    // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
    // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
    // constructed and OpAll and OpAny used.
    spirv::IdRef reduceBoolVector(TOperator op,
                                  const spirv::IdRefList &valueIds,
                                  spirv::IdRef typeId,
                                  const SpirvDecorations &decorations);
    // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
    void createCompareImpl(TOperator op,
                           const TType &operandType,
                           spirv::IdRef resultTypeId,
                           spirv::IdRef leftId,
                           spirv::IdRef rightId,
                           const SpirvDecorations &operandDecorations,
                           const SpirvDecorations &resultDecorations,
                           spirv::LiteralIntegerList *currentAccessChain,
                           spirv::IdRefList *intermediateResultsOut);

    // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
    // struct if not already defined.
    spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
    // Once the builtin instruction is generated, the two return values are extracted from the
    // struct.  These are written to the return value (if any) and the out parameters.
    void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
                                                        size_t lvalueCount,
                                                        spirv::IdRef structValue,
                                                        spirv::IdRef returnValue,
                                                        spirv::IdRef returnValueType);
    void storeBuiltInStructOutputInParamHelper(NodeData *data,
                                               TIntermTyped *param,
                                               spirv::IdRef structValue,
                                               uint32_t fieldIndex);

    void markVertexOutputOnShaderEnd();
    void markVertexOutputOnEmitVertex();

    TCompiler *mCompiler;
    ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;

    SPIRVBuilder mBuilder;

    // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
    // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
    // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
    std::vector<NodeData> mNodeData;

    // A map of TSymbol to its SPIR-V id.  This could be a:
    //
    // - TVariable, or
    // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
    //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
    angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;

    // A map of TFunction to its various SPIR-V ids.
    angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;

    // A map of internally defined structs used to capture result of some SPIR-V instructions.
    angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
        mBuiltInResultStructMap;

    // Whether the current symbol being visited is being declared.
    bool mIsSymbolBeingDeclared = false;

    // What is the id of the current function being generated.
    spirv::IdRef mCurrentFunctionId;
};
400
GetStorageClass(const ShCompileOptions & compileOptions,const TType & type,GLenum shaderType)401 spv::StorageClass GetStorageClass(const ShCompileOptions &compileOptions,
402 const TType &type,
403 GLenum shaderType)
404 {
405 // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
406 if (IsOpaqueType(type.getBasicType()))
407 {
408 return spv::StorageClassUniformConstant;
409 }
410
411 const TQualifier qualifier = type.getQualifier();
412
413 // Input varying and IO blocks have the Input storage class
414 if (IsShaderIn(qualifier))
415 {
416 return spv::StorageClassInput;
417 }
418
419 // Output varying and IO blocks have the Input storage class
420 if (IsShaderOut(qualifier))
421 {
422 return spv::StorageClassOutput;
423 }
424
425 switch (qualifier)
426 {
427 case EvqShared:
428 // Compute shader shared memory has the Workgroup storage class
429 return spv::StorageClassWorkgroup;
430
431 case EvqGlobal:
432 case EvqConst:
433 // Global variables have the Private class. Complex constant variables that are not
434 // folded are also defined globally.
435 return spv::StorageClassPrivate;
436
437 case EvqTemporary:
438 case EvqParamIn:
439 case EvqParamOut:
440 case EvqParamInOut:
441 // Function-local variables have the Function class
442 return spv::StorageClassFunction;
443
444 case EvqVertexID:
445 case EvqInstanceID:
446 case EvqFragCoord:
447 case EvqFrontFacing:
448 case EvqPointCoord:
449 case EvqSampleID:
450 case EvqSamplePosition:
451 case EvqSampleMaskIn:
452 case EvqPatchVerticesIn:
453 case EvqTessCoord:
454 case EvqPrimitiveIDIn:
455 case EvqInvocationID:
456 case EvqHelperInvocation:
457 case EvqNumWorkGroups:
458 case EvqWorkGroupID:
459 case EvqLocalInvocationID:
460 case EvqGlobalInvocationID:
461 case EvqLocalInvocationIndex:
462 case EvqViewIDOVR:
463 case EvqLayerIn:
464 case EvqLastFragColor:
465 case EvqLastFragDepth:
466 case EvqLastFragStencil:
467 return spv::StorageClassInput;
468
469 case EvqPosition:
470 case EvqPointSize:
471 case EvqFragDepth:
472 case EvqSampleMask:
473 case EvqLayerOut:
474 return spv::StorageClassOutput;
475
476 case EvqClipDistance:
477 case EvqCullDistance:
478 // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
479 // otherwise.
480 return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
481 : spv::StorageClassOutput;
482
483 case EvqTessLevelOuter:
484 case EvqTessLevelInner:
485 // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
486 return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
487 : spv::StorageClassInput;
488
489 case EvqPrimitiveID:
490 // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
491 return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
492 : spv::StorageClassInput;
493
494 default:
495 // Uniform buffers have the Uniform storage class. Storage buffers have the Uniform
496 // storage class in SPIR-V 1.3, and the StorageBuffer storage class in SPIR-V 1.4.
497 // Default uniforms are gathered in a uniform block as well. Push constants use the
498 // PushConstant storage class instead.
499 ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
500 // I/O blocks must have already been classified as input or output above.
501 ASSERT(!IsShaderIoBlock(qualifier));
502
503 if (type.getLayoutQualifier().pushConstant)
504 {
505 ASSERT(type.getInterfaceBlock() != nullptr);
506 return spv::StorageClassPushConstant;
507 }
508 return compileOptions.emitSPIRV14 && qualifier == EvqBuffer
509 ? spv::StorageClassStorageBuffer
510 : spv::StorageClassUniform;
511 }
512 }
513
OutputSPIRVTraverser(TCompiler * compiler,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)514 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
515 const ShCompileOptions &compileOptions,
516 const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
517 uint32_t firstUnusedSpirvId)
518 : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
519 mCompiler(compiler),
520 mCompileOptions(compileOptions),
521 mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
522 {}
523
~OutputSPIRVTraverser()524 OutputSPIRVTraverser::~OutputSPIRVTraverser()
525 {
526 ASSERT(mNodeData.empty());
527 }
528
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)529 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
530 const TType &type,
531 spv::StorageClass *storageClass)
532 {
533 *storageClass = GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
534 auto iter = mSymbolIdMap.find(symbol);
535 if (iter != mSymbolIdMap.end())
536 {
537 return iter->second;
538 }
539
540 // This must be an implicitly defined variable, define it now.
541 const char *name = nullptr;
542 spv::BuiltIn builtInDecoration = spv::BuiltInMax;
543 const TSymbolUniqueId *uniqueId = nullptr;
544
545 switch (type.getQualifier())
546 {
547 // Vertex shader built-ins
548 case EvqVertexID:
549 name = "gl_VertexIndex";
550 builtInDecoration = spv::BuiltInVertexIndex;
551 break;
552 case EvqInstanceID:
553 name = "gl_InstanceIndex";
554 builtInDecoration = spv::BuiltInInstanceIndex;
555 break;
556
557 // Fragment shader built-ins
558 case EvqFragCoord:
559 name = "gl_FragCoord";
560 builtInDecoration = spv::BuiltInFragCoord;
561 break;
562 case EvqFrontFacing:
563 name = "gl_FrontFacing";
564 builtInDecoration = spv::BuiltInFrontFacing;
565 break;
566 case EvqPointCoord:
567 name = "gl_PointCoord";
568 builtInDecoration = spv::BuiltInPointCoord;
569 break;
570 case EvqFragDepth:
571 name = "gl_FragDepth";
572 builtInDecoration = spv::BuiltInFragDepth;
573 mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
574 switch (type.getLayoutQualifier().depth)
575 {
576 case EdGreater:
577 mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
578 break;
579 case EdLess:
580 mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
581 break;
582 case EdUnchanged:
583 mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
584 break;
585 default:
586 break;
587 }
588 break;
589 case EvqSampleMask:
590 name = "gl_SampleMask";
591 builtInDecoration = spv::BuiltInSampleMask;
592 break;
593 case EvqSampleMaskIn:
594 name = "gl_SampleMaskIn";
595 builtInDecoration = spv::BuiltInSampleMask;
596 break;
597 case EvqSampleID:
598 name = "gl_SampleID";
599 builtInDecoration = spv::BuiltInSampleId;
600 mBuilder.addCapability(spv::CapabilitySampleRateShading);
601 uniqueId = &symbol->uniqueId();
602 break;
603 case EvqSamplePosition:
604 name = "gl_SamplePosition";
605 builtInDecoration = spv::BuiltInSamplePosition;
606 mBuilder.addCapability(spv::CapabilitySampleRateShading);
607 break;
608 case EvqClipDistance:
609 name = "gl_ClipDistance";
610 builtInDecoration = spv::BuiltInClipDistance;
611 mBuilder.addCapability(spv::CapabilityClipDistance);
612 break;
613 case EvqCullDistance:
614 name = "gl_CullDistance";
615 builtInDecoration = spv::BuiltInCullDistance;
616 mBuilder.addCapability(spv::CapabilityCullDistance);
617 break;
618 case EvqHelperInvocation:
619 name = "gl_HelperInvocation";
620 builtInDecoration = spv::BuiltInHelperInvocation;
621 break;
622
623 // Tessellation built-ins
624 case EvqPatchVerticesIn:
625 name = "gl_PatchVerticesIn";
626 builtInDecoration = spv::BuiltInPatchVertices;
627 break;
628 case EvqTessLevelOuter:
629 name = "gl_TessLevelOuter";
630 builtInDecoration = spv::BuiltInTessLevelOuter;
631 break;
632 case EvqTessLevelInner:
633 name = "gl_TessLevelInner";
634 builtInDecoration = spv::BuiltInTessLevelInner;
635 break;
636 case EvqTessCoord:
637 name = "gl_TessCoord";
638 builtInDecoration = spv::BuiltInTessCoord;
639 break;
640
641 // Shared geometry and tessellation built-ins
642 case EvqInvocationID:
643 name = "gl_InvocationID";
644 builtInDecoration = spv::BuiltInInvocationId;
645 break;
646 case EvqPrimitiveID:
647 name = "gl_PrimitiveID";
648 builtInDecoration = spv::BuiltInPrimitiveId;
649
650 // In fragment shader, add the Geometry capability.
651 if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
652 {
653 mBuilder.addCapability(spv::CapabilityGeometry);
654 }
655
656 break;
657
658 // Geometry shader built-ins
659 case EvqPrimitiveIDIn:
660 name = "gl_PrimitiveIDIn";
661 builtInDecoration = spv::BuiltInPrimitiveId;
662 break;
663 case EvqLayerOut:
664 case EvqLayerIn:
665 name = "gl_Layer";
666 builtInDecoration = spv::BuiltInLayer;
667
668 // gl_Layer requires the Geometry capability, even in fragment shaders.
669 mBuilder.addCapability(spv::CapabilityGeometry);
670
671 break;
672
673 // Compute shader built-ins
674 case EvqNumWorkGroups:
675 name = "gl_NumWorkGroups";
676 builtInDecoration = spv::BuiltInNumWorkgroups;
677 break;
678 case EvqWorkGroupID:
679 name = "gl_WorkGroupID";
680 builtInDecoration = spv::BuiltInWorkgroupId;
681 break;
682 case EvqLocalInvocationID:
683 name = "gl_LocalInvocationID";
684 builtInDecoration = spv::BuiltInLocalInvocationId;
685 break;
686 case EvqGlobalInvocationID:
687 name = "gl_GlobalInvocationID";
688 builtInDecoration = spv::BuiltInGlobalInvocationId;
689 break;
690 case EvqLocalInvocationIndex:
691 name = "gl_LocalInvocationIndex";
692 builtInDecoration = spv::BuiltInLocalInvocationIndex;
693 break;
694
695 // Extensions
696 case EvqViewIDOVR:
697 name = "gl_ViewID_OVR";
698 builtInDecoration = spv::BuiltInViewIndex;
699 mBuilder.addCapability(spv::CapabilityMultiView);
700 mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
701 break;
702
703 default:
704 UNREACHABLE();
705 }
706
707 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
708 const spirv::IdRef varId = mBuilder.declareVariable(
709 typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);
710
711 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
712 {spirv::LiteralInteger(builtInDecoration)});
713
714 // Additionally:
715 //
716 // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
717 // - decorate gl_TessLevel* with Patch.
718 switch (type.getQualifier())
719 {
720 case EvqLayerIn:
721 case EvqSampleID:
722 case EvqPrimitiveID:
723 case EvqViewIDOVR:
724 if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
725 {
726 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
727 {});
728 }
729 break;
730 case EvqTessLevelInner:
731 case EvqTessLevelOuter:
732 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
733 break;
734 default:
735 break;
736 }
737
738 mSymbolIdMap.insert({symbol, varId});
739 return varId;
740 }
741
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const742 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
743 spirv::IdRef baseId,
744 spirv::IdRef typeId,
745 spv::StorageClass storageClass,
746 const SpirvTypeSpec &typeSpec) const
747 {
748 *data = {};
749
750 // Initialize the access chain as an lvalue. Useful when an access chain is resolved, but needs
751 // to be replaced by a reference to a temporary variable holding the result.
752 data->baseId = baseId;
753 data->accessChain.baseTypeId = typeId;
754 data->accessChain.preSwizzleTypeId = typeId;
755 data->accessChain.storageClass = storageClass;
756 data->accessChain.typeSpec = typeSpec;
757 }
758
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const759 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
760 spirv::IdRef baseId,
761 spirv::IdRef typeId) const
762 {
763 *data = {};
764
765 // Initialize the access chain as an rvalue. Useful when an access chain is resolved, and needs
766 // to be replaced by a reference to it.
767 data->baseId = baseId;
768 data->accessChain.baseTypeId = typeId;
769 data->accessChain.preSwizzleTypeId = typeId;
770 }
771
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)772 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
773 {
774 AccessChain &accessChain = data->accessChain;
775
776 // Adjust |typeSpec| based on the type (which implies what the index does; select an array
777 // element, a block field etc). Index is only meaningful for selecting block fields.
778 if (parentType.isArray())
779 {
780 accessChain.typeSpec.onArrayElementSelection(
781 (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
782 parentType.isArrayOfArrays());
783 }
784 else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
785 {
786 const TFieldListCollection *block = parentType.getInterfaceBlock();
787 if (!parentType.isInterfaceBlock())
788 {
789 block = parentType.getStruct();
790 }
791
792 const TType &fieldType = *block->fields()[index]->type();
793 accessChain.typeSpec.onBlockFieldSelection(fieldType);
794 }
795 else if (parentType.isMatrix())
796 {
797 accessChain.typeSpec.onMatrixColumnSelection();
798 }
799 else
800 {
801 ASSERT(parentType.isVector());
802 accessChain.typeSpec.onVectorComponentSelection();
803 }
804 }
805
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const806 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
807 spirv::IdRef index,
808 spirv::IdRef typeId) const
809 {
810 // Simply add the index to the chain of indices.
811 data->idList.emplace_back(index);
812 data->accessChain.areAllIndicesLiteral = false;
813 data->accessChain.preSwizzleTypeId = typeId;
814 }
815
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const816 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
817 spirv::LiteralInteger index,
818 spirv::IdRef typeId) const
819 {
820 // Add the literal integer in the chain of indices. Since this is an id list, fake it as an id.
821 data->idList.emplace_back(index);
822 data->accessChain.preSwizzleTypeId = typeId;
823
824 // Literal index of a swizzle must be already folded in the AST.
825 ASSERT(data->accessChain.swizzles.empty());
826 }
827
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const828 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
829 const TVector<int> &swizzle,
830 spirv::IdRef typeId,
831 uint8_t componentCount) const
832 {
833 AccessChain &accessChain = data->accessChain;
834
835 // Record the swizzle as multi-component swizzles require special handling. When loading
836 // through the access chain, the swizzle is applied after loading the vector first (see
837 // |accessChainLoad()|). When storing through the access chain, the whole vector is loaded,
838 // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
839 ASSERT(accessChain.swizzles.empty());
840
841 if (swizzle.size() == 1)
842 {
843 // If this swizzle is selecting a single component, fold it into the access chain.
844 accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
845 }
846 else
847 {
848 // Otherwise keep them separate.
849 accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
850 accessChain.postSwizzleTypeId = typeId;
851 accessChain.swizzledVectorComponentCount = componentCount;
852 }
853 }
854
// Pushes a dynamically-indexed single-component selection of a vector (e.g. |vec[i]| where |i| is
// not a compile-time constant) onto the access chain.
void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
                                                           spirv::IdRef index,
                                                           spirv::IdRef typeId)
{
    AccessChain &accessChain = data->accessChain;

    // Record the index used to dynamically select a component of a vector.
    ASSERT(!accessChain.dynamicComponent.valid());

    if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
    {
        // If the access chain is an rvalue with all-literal indices, keep this index separate so
        // that OpCompositeExtract can be used for the access chain up to this index.
        accessChain.dynamicComponent           = index;
        accessChain.postDynamicComponentTypeId = typeId;
        return;
    }

    if (!accessChain.swizzles.empty())
    {
        // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
        // single dynamic component selection.
        ASSERT(accessChain.swizzles.size() > 1);

        // Create a vector constant from the swizzles.  (Single-component swizzles were already
        // folded into the index list by |accessChainPushSwizzle|.)
        spirv::IdRefList swizzleIds;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleIds.push_back(mBuilder.getUintConstant(component));
        }

        const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
        const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());

        const spirv::IdRef swizzlesId = mBuilder.getNewId({});
        spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
                                      swizzlesId, swizzleIds);

        // Index that vector constant with the dynamic index. For example, vec.ywxz[i] becomes the
        // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
        const spirv::IdRef newIndex = mBuilder.getNewId({});
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
                                         newIndex, swizzlesId, index);

        // The remapped index replaces the original; the swizzle is now fully accounted for.
        index = newIndex;
        accessChain.swizzles.clear();
    }

    // Fold it into the access chain.
    accessChainPush(data, index, typeId);
}
906
accessChainCollapse(NodeData * data)907 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
908 {
909 AccessChain &accessChain = data->accessChain;
910
911 ASSERT(accessChain.storageClass != spv::StorageClassMax);
912
913 if (accessChain.accessChainId.valid())
914 {
915 return accessChain.accessChainId;
916 }
917
918 // If there are no indices, the baseId is where access is done to/from.
919 if (data->idList.empty())
920 {
921 accessChain.accessChainId = data->baseId;
922 return accessChain.accessChainId;
923 }
924
925 // Otherwise create an OpAccessChain instruction. Swizzle handling is special as it selects
926 // multiple components, and is done differently for load and store.
927 spirv::IdRefList indexIds;
928 makeAccessChainIdList(data, &indexIds);
929
930 const spirv::IdRef typePointerId =
931 mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
932
933 accessChain.accessChainId = mBuilder.getNewId({});
934 spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
935 accessChain.accessChainId, data->baseId, indexIds);
936
937 return accessChain.accessChainId;
938 }
939
// Loads the value described by the access chain accumulated in |data| and returns the id of the
// result.  |valueType| is the AST type of the expression, used for decorations and for casting the
// loaded value to the default SPIR-V type variant.  If |resultTypeIdOut| is non-null, it receives
// the SPIR-V type id of the loaded expression.
spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
                                                   const TType &valueType,
                                                   spirv::IdRef *resultTypeIdOut)
{
    const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);

    // Loading through the access chain can generate different instructions based on whether it's an
    // rvalue, the indices are literal, there's a swizzle etc.
    //
    // - If rvalue:
    //  * With indices:
    //    + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
    //    + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
    //      the rvalue into it, then use OpAccessChain and OpLoad to load from it.
    //  * Without indices: Take the base id.
    // - If lvalue:
    //  * With indices: Use OpAccessChain and OpLoad
    //  * Without indices: Use OpLoad
    // - With swizzle: Use OpVectorShuffle on the result of the previous step
    // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step

    AccessChain &accessChain = data->accessChain;

    if (resultTypeIdOut)
    {
        *resultTypeIdOut = getAccessChainTypeId(data);
    }

    // With no indices on an rvalue, the base id is the result as-is.
    spirv::IdRef loadResult = data->baseId;

    if (IsAccessChainRValue(accessChain))
    {
        if (data->idList.size() > 0)
        {
            if (accessChain.areAllIndicesLiteral)
            {
                // Use OpCompositeExtract on an rvalue with all literal indices.
                spirv::LiteralIntegerList indexList;
                makeAccessChainLiteralList(data, &indexList);

                const spirv::IdRef result = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             accessChain.preSwizzleTypeId, result, loadResult,
                                             indexList);
                loadResult = result;
            }
            else
            {
                // Create a temp variable to hold the rvalue so an access chain can be made on it.
                const spirv::IdRef tempVar =
                    mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
                                             decorations, nullptr, "indexable", nullptr);

                // Write the rvalue into the temp variable
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
                                  nullptr);

                // Make the temp variable the source of the access chain.  Note this mutates |data|
                // so subsequent operations on this node see the temp variable as the base.
                data->baseId                   = tempVar;
                data->accessChain.storageClass = spv::StorageClassFunction;

                // Load from the temp variable.
                const spirv::IdRef accessChainId = accessChainCollapse(data);
                loadResult                       = mBuilder.getNewId(decorations);
                spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
                                 accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
            }
        }
    }
    else
    {
        // Load from the access chain.
        const spirv::IdRef accessChainId = accessChainCollapse(data);
        loadResult                       = mBuilder.getNewId(decorations);
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);
    }

    if (!accessChain.swizzles.empty())
    {
        // Single-component swizzles are already folded into the index list.
        ASSERT(accessChain.swizzles.size() > 1);

        // Take the loaded value and use OpVectorShuffle to create the swizzle.  Both shuffle
        // operands are the same vector; the swizzle indices select from it.
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component : accessChain.swizzles)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }

        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.postSwizzleTypeId, result, loadResult, loadResult,
                                  swizzleList);
        loadResult = result;
    }

    if (accessChain.dynamicComponent.valid())
    {
        // Use OpVectorExtractDynamic to select the component.
        const spirv::IdRef result = mBuilder.getNewId(decorations);
        spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
                                         accessChain.postDynamicComponentTypeId, result, loadResult,
                                         accessChain.dynamicComponent);
        loadResult = result;
    }

    // Upon loading values, cast them to the default SPIR-V variant.
    const spirv::IdRef castResult =
        cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);

    return castResult;
}
1053
// Stores |value| (of AST type |valueType|) through the access chain accumulated in |data|.
void OutputSPIRVTraverser::accessChainStore(NodeData *data,
                                            spirv::IdRef value,
                                            const TType &valueType)
{
    // Storing through the access chain can generate different instructions based on whether the
    // there's a swizzle.
    //
    // - Without swizzle: Use OpAccessChain and OpStore
    // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
    //   replace the components being overwritten.  Finally, use OpStore to write the result back.

    AccessChain &accessChain = data->accessChain;

    // Single-component swizzles are already folded into the indices.
    ASSERT(accessChain.swizzles.size() != 1);
    // Since store can only happen through lvalues, it's impossible to have a dynamic component as
    // that always gets folded into the indices except for rvalues.
    ASSERT(!accessChain.dynamicComponent.valid());

    const spirv::IdRef accessChainId = accessChainCollapse(data);

    // Store through the access chain.  The values are always cast to the default SPIR-V type
    // variant when loaded from memory and operated on as such.  When storing, we need to cast the
    // result to the variant specified by the access chain.
    value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);

    if (!accessChain.swizzles.empty())
    {
        // Load the vector before the swizzle.
        const spirv::IdRef loadResult = mBuilder.getNewId({});
        spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
                         loadResult, accessChainId, nullptr);

        // Overwrite the components being written.  This is done by first creating an identity
        // swizzle, then replacing the components being written with a swizzle from the value.  For
        // example, take the following:
        //
        //     vec4 v;
        //     v.zx = u;
        //
        // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
        // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
        // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
        // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
        // {4+1, 1, 4+0, 3}.

        // Identity swizzle: every component initially selects itself from the loaded vector.
        spirv::LiteralIntegerList swizzleList;
        for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
             ++component)
        {
            swizzleList.push_back(spirv::LiteralInteger(component));
        }
        // Redirect the written components to select from |value| instead (indices offset by the
        // destination vector's component count, per OpVectorShuffle semantics).
        uint32_t srcComponent = 0;
        for (uint32_t dstComponent : accessChain.swizzles)
        {
            swizzleList[dstComponent] =
                spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
            ++srcComponent;
        }

        // Use the generated swizzle to select components from the loaded vector and the value to be
        // written.  Use the final result as the value to be written to the vector.
        const spirv::IdRef result = mBuilder.getNewId({});
        spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
                                  accessChain.preSwizzleTypeId, result, loadResult, value,
                                  swizzleList);
        value = result;
    }

    spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
}
1125
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1126 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1127 {
1128 for (size_t index = 0; index < data->idList.size(); ++index)
1129 {
1130 spirv::IdRef indexId = data->idList[index].id;
1131
1132 if (!indexId.valid())
1133 {
1134 // The index is a literal integer, so replace it with an OpConstant id.
1135 indexId = mBuilder.getUintConstant(data->idList[index].literal);
1136 }
1137
1138 idsOut->push_back(indexId);
1139 }
1140 }
1141
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1142 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1143 spirv::LiteralIntegerList *literalsOut)
1144 {
1145 for (size_t index = 0; index < data->idList.size(); ++index)
1146 {
1147 ASSERT(!data->idList[index].id.valid());
1148 literalsOut->push_back(data->idList[index].literal);
1149 }
1150 }
1151
getAccessChainTypeId(NodeData * data)1152 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1153 {
1154 // Load and store through the access chain may be done in multiple steps. These steps produce
1155 // the following types:
1156 //
1157 // - preSwizzleTypeId
1158 // - postSwizzleTypeId
1159 // - postDynamicComponentTypeId
1160 //
1161 // The last of these types is the final type of the expression this access chain corresponds to.
1162 const AccessChain &accessChain = data->accessChain;
1163
1164 if (accessChain.postDynamicComponentTypeId.valid())
1165 {
1166 return accessChain.postDynamicComponentTypeId;
1167 }
1168 if (accessChain.postSwizzleTypeId.valid())
1169 {
1170 return accessChain.postSwizzleTypeId;
1171 }
1172 ASSERT(accessChain.preSwizzleTypeId.valid());
1173 return accessChain.preSwizzleTypeId;
1174 }
1175
declareConst(TIntermDeclaration * decl)1176 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1177 {
1178 const TIntermSequence &sequence = *decl->getSequence();
1179 ASSERT(sequence.size() == 1);
1180
1181 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1182 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1183
1184 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1185 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1186
1187 TIntermTyped *initializer = assign->getRight();
1188 ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1189
1190 const TType &type = symbol->getType();
1191 const TVariable *variable = &symbol->variable();
1192
1193 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1194 const spirv::IdRef constId =
1195 createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1196 initializer->isConstantNullValue());
1197
1198 // Remember the id of the variable for future look up.
1199 ASSERT(mSymbolIdMap.count(variable) == 0);
1200 mSymbolIdMap[variable] = constId;
1201
1202 if (!mInGlobalScope)
1203 {
1204 mNodeData.emplace_back();
1205 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1206 }
1207 }
1208
declareSpecConst(TIntermDeclaration * decl)1209 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1210 {
1211 const TIntermSequence &sequence = *decl->getSequence();
1212 ASSERT(sequence.size() == 1);
1213
1214 TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1215 ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1216
1217 TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1218 ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1219
1220 TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1221 ASSERT(initializer != nullptr);
1222
1223 const TType &type = symbol->getType();
1224 const TVariable *variable = &symbol->variable();
1225
1226 // All spec consts in ANGLE are initialized to 0.
1227 ASSERT(initializer->isZero(0));
1228
1229 const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1230 type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1231
1232 // Remember the id of the variable for future look up.
1233 ASSERT(mSymbolIdMap.count(variable) == 0);
1234 mSymbolIdMap[variable] = specConstId;
1235 }
1236
// Creates (or reuses) constant ids for a (possibly composite) constant of |type|.  |constUnion|
// points into a flat array of components laid out in declaration order; recursion for arrays and
// structs advances the pointer by each element's/field's object size.  |expectedBasicType| is the
// component type the caller needs; components of a different basic type are cast on the fly.
spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
                                                  TBasicType expectedBasicType,
                                                  const TConstantUnion *constUnion,
                                                  bool isConstantNullValue)
{
    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
    spirv::IdRefList componentIds;

    // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
    // is not done for basic scalar types as some instructions require an OpConstant and validation
    // doesn't accept OpConstantNull (likely a spec bug).
    const size_t size          = type.getObjectSize();
    const TBasicType basicType = type.getBasicType();
    const bool isBasicScalar   = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
                                             basicType == EbtUInt || basicType == EbtBool);
    const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
    if (useOpConstantNull)
    {
        return mBuilder.getNullConstant(typeId);
    }

    if (type.isArray())
    {
        TType elementType(type);
        elementType.toArrayElementType();

        // If it's an array constant, get the constant id of each element.  Note: the
        // isConstantNullValue case was handled above, so elements are never null here.
        for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
             ++elementIndex)
        {
            componentIds.push_back(
                createConstant(elementType, expectedBasicType, constUnion, false));
            constUnion += elementType.getObjectSize();
        }
    }
    else if (type.getBasicType() == EbtStruct)
    {
        // If it's a struct constant, get the constant id for each field.  Each field uses its own
        // basic type as the expected type.
        for (const TField *field : type.getStruct()->fields())
        {
            const TType *fieldType = field->type();
            componentIds.push_back(
                createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));

            constUnion += fieldType->getObjectSize();
        }
    }
    else
    {
        // Otherwise get the constant id for each component.
        ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
               expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
               expectedBasicType == EbtYuvCscStandardEXT);

        for (size_t component = 0; component < size; ++component, ++constUnion)
        {
            spirv::IdRef componentId;

            // If the constant has a different type than expected, cast it right away.
            TConstantUnion castConstant;
            bool valid = castConstant.cast(expectedBasicType, *constUnion);
            ASSERT(valid);

            switch (castConstant.getType())
            {
                case EbtFloat:
                    componentId = mBuilder.getFloatConstant(castConstant.getFConst());
                    break;
                case EbtInt:
                    componentId = mBuilder.getIntConstant(castConstant.getIConst());
                    break;
                case EbtUInt:
                    componentId = mBuilder.getUintConstant(castConstant.getUConst());
                    break;
                case EbtBool:
                    componentId = mBuilder.getBoolConstant(castConstant.getBConst());
                    break;
                case EbtYuvCscStandardEXT:
                    // Yuv standards are encoded as uint constants.
                    componentId =
                        mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
                    break;
                default:
                    UNREACHABLE();
            }
            componentIds.push_back(componentId);
        }
    }

    // If this is a composite, create a composite constant from the components.
    if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
    {
        return createComplexConstant(type, typeId, componentIds);
    }

    // Otherwise return the sole component.
    ASSERT(componentIds.size() == 1);
    return componentIds[0];
}
1335
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1336 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1337 spirv::IdRef typeId,
1338 const spirv::IdRefList ¶meters)
1339 {
1340 ASSERT(!type.isScalar());
1341
1342 if (type.isMatrix() && !type.isArray())
1343 {
1344 ASSERT(parameters.size() == type.getRows() * type.getCols());
1345
1346 // Matrices are constructed from their columns.
1347 spirv::IdRefList columnIds;
1348
1349 const spirv::IdRef columnTypeId =
1350 mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1351
1352 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1353 {
1354 auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1355 spirv::IdRefList columnParameters(columnParametersStart,
1356 columnParametersStart + type.getRows());
1357
1358 columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1359 }
1360
1361 return mBuilder.getCompositeConstant(typeId, columnIds);
1362 }
1363
1364 return mBuilder.getCompositeConstant(typeId, parameters);
1365 }
1366
// Generates SPIR-V for a GLSL constructor expression, dispatching to the appropriate helper based
// on the constructed type and the argument shapes.  Returns the id of the constructed value.
spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
{
    const TType &type            = node->getType();
    const TIntermSequence &arguments = *node->getSequence();
    const TType &arg0Type        = arguments[0]->getAsTyped()->getType();

    // In some cases, constructors-with-constant values are not folded.  If the constructor is a
    // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
    // constant is created below.
    if (node->isConstantNullValue())
    {
        return mBuilder.getNullConstant(typeId);
    }

    // Take each constructor argument that is visited and evaluate it as rvalue
    spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);

    // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
    // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
    //
    // - float(f): This should translate to just f
    // - float(v): This should translate to OpCompositeExtract %scalar %v 0
    // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
    // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
    // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
    //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
    //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
    //   used as parameter).
    // - vecN(m): This takes N components from m in column-major order (for example, vec4
    //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
    //   This translates to OpCompositeConstruct with the id of the individual components extracted
    //   from m.
    // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
    //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
    //   constructs the final result.
    // - matNxM(m):
    //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
    //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
    //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
    //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
    //     OpCompositeExtract is used to extract each component of m (that is necessary), and
    //     together with the zero or one constants necessary used to create the columns (with
    //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
    // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
    //   are extracted from the parameters, which are divided up and used to construct each column,
    //   which is finally constructed into the final result.
    //
    // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
    // each parameter which must enumerate every individual element / field.

    // In some cases, constructors-with-constant values are not folded such as for large constants.
    // Some transformations may also produce constructors-with-constants instead of constants even
    // for basic types.  These are handled here.
    if (node->hasConstantValue())
    {
        if (!type.isScalar())
        {
            return createComplexConstant(node->getType(), typeId, parameters);
        }

        // If a transformation creates scalar(constant), return the constant as-is.
        // visitConstantUnion has already cast it to the right type.
        if (arguments[0]->getAsConstantUnion() != nullptr)
        {
            return parameters[0];
        }
    }

    if (type.isArray() || type.getStruct() != nullptr)
    {
        return createArrayOrStructConstructor(node, typeId, parameters);
    }

    // The following are simple casts:
    //
    // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
    // - gvecN(vN) (where the argument is a single vector with the same number of components).
    // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
    //   matrices are always float, so there's no actual cast and this would be a no-op.
    //
    const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
    const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
                                    arg0Type.isVector() &&
                                    type.getNominalSize() == arg0Type.getNominalSize();
    const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
                                    arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
                                    type.getRows() == arg0Type.getRows();
    if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
    {
        return castBasicType(parameters[0], arg0Type, type, nullptr);
    }

    if (type.isScalar())
    {
        // scalar(v) / scalar(m): extract the first component.
        ASSERT(arguments.size() == 1);
        return createConstructorScalarFromNonScalar(node, typeId, parameters);
    }

    if (type.isVector())
    {
        if (arguments.size() == 1 && arg0Type.isScalar())
        {
            return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
        }
        if (arg0Type.isMatrix())
        {
            // If the first argument is a matrix, it will always have enough components to fill an
            // entire vector, so it doesn't matter what's specified after it.
            return createConstructorVectorFromMatrix(node, typeId, parameters);
        }
        return createConstructorVectorFromMultiple(node, typeId, parameters);
    }

    ASSERT(type.isMatrix());

    if (arg0Type.isScalar() && arguments.size() == 1)
    {
        // matNxM(f): diagonal matrix; cast the scalar first so the diagonal has the right type.
        parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
        return createConstructorMatrixFromScalar(node, typeId, parameters);
    }
    if (arg0Type.isMatrix())
    {
        return createConstructorMatrixFromMatrix(node, typeId, parameters);
    }
    return createConstructorMatrixFromVectors(node, typeId, parameters);
}
1493
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1494 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1495 TIntermAggregate *node,
1496 spirv::IdRef typeId,
1497 const spirv::IdRefList ¶meters)
1498 {
1499 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1500 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1501 parameters);
1502 return result;
1503 }
1504
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1505 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1506 TIntermAggregate *node,
1507 spirv::IdRef typeId,
1508 const spirv::IdRefList ¶meters)
1509 {
1510 ASSERT(parameters.size() == 1);
1511 const TType &type = node->getType();
1512 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1513
1514 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1515
1516 spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1517 if (arg0Type.isMatrix())
1518 {
1519 indices.push_back(spirv::LiteralInteger(0));
1520 }
1521
1522 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1523 mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1524 parameters[0], indices);
1525
1526 TType arg0TypeAsScalar(arg0Type);
1527 arg0TypeAsScalar.toComponentType();
1528
1529 return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1530 }
1531
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1532 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1533 const TType ¶meterType,
1534 const TType &expectedType,
1535 spirv::IdRef typeId,
1536 const spirv::IdRefList ¶meters)
1537 {
1538 // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1539 ASSERT(parameters.size() == 1);
1540
1541 const spirv::IdRef castParameter =
1542 castBasicType(parameters[0], parameterType, expectedType, nullptr);
1543
1544 spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1545
1546 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1547 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1548 replicatedParameter);
1549 return result;
1550 }
1551
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1552 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1553 TIntermAggregate *node,
1554 spirv::IdRef typeId,
1555 const spirv::IdRefList ¶meters)
1556 {
1557 // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1558 spirv::IdRefList extractedComponents;
1559 extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1560
1561 // Construct the vector with the basic type of the argument, and cast it at end if needed.
1562 ASSERT(parameters.size() == 1);
1563 const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1564 const TType &expectedType = node->getType();
1565
1566 spirv::IdRef argumentTypeId = typeId;
1567 TType arg0TypeAsVector(arg0Type);
1568 arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1569 arg0TypeAsVector.setSecondarySize(1);
1570
1571 if (arg0Type.getBasicType() != expectedType.getBasicType())
1572 {
1573 argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1574 }
1575
1576 spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1577 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1578 extractedComponents);
1579
1580 if (arg0Type.getBasicType() != expectedType.getBasicType())
1581 {
1582 result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1583 }
1584
1585 return result;
1586 }
1587
// Constructs a vector out of multiple scalar/vector/matrix arguments, e.g. vec4(v1.zy, v2.x, f).
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    const TType &type = node->getType();
    // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
    spirv::IdRefList extractedComponents;
    extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);

    // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
    // that case, the components extracted from the matrix might need casting to the right type.
    // |componentIndex| tracks the position in |extractedComponents| that corresponds to the start
    // of each argument.
    const TIntermSequence &arguments = *node->getSequence();
    for (size_t argumentIndex = 0, componentIndex = 0;
         argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
         ++argumentIndex)
    {
        TIntermNode *argument     = arguments[argumentIndex];
        const TType &argumentType = argument->getAsTyped()->getType();
        if (argumentType.isScalar() || argumentType.isVector())
        {
            // extractComponents already casts scalar and vector components.
            componentIndex += argumentType.getNominalSize();
            continue;
        }

        // A matrix argument: cast each of its extracted components individually.
        TType componentType(argumentType);
        componentType.toComponentType();

        for (uint8_t columnIndex = 0;
             columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
             ++columnIndex)
        {
            for (uint8_t rowIndex = 0;
                 rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
                 ++rowIndex, ++componentIndex)
            {
                extractedComponents[componentIndex] = castBasicType(
                    extractedComponents[componentIndex], componentType, type, nullptr);
            }
        }

        // Matrices all have enough components to fill a vector, so it's impossible to need to visit
        // any other arguments that may come after.
        ASSERT(componentIndex == extractedComponents.size());
    }

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   extractedComponents);
    return result;
}
1640
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1641 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1642 TIntermAggregate *node,
1643 spirv::IdRef typeId,
1644 const spirv::IdRefList ¶meters)
1645 {
1646 // matNxM(f) translates to
1647 //
1648 // %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1649 // %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1650 // %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1651 // ...
1652 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1653
1654 const TType &type = node->getType();
1655 const spirv::IdRef scalarId = parameters[0];
1656 spirv::IdRef zeroId;
1657
1658 SpirvDecorations decorations = mBuilder.getDecorations(type);
1659
1660 switch (type.getBasicType())
1661 {
1662 case EbtFloat:
1663 zeroId = mBuilder.getFloatConstant(0);
1664 break;
1665 case EbtInt:
1666 zeroId = mBuilder.getIntConstant(0);
1667 break;
1668 case EbtUInt:
1669 zeroId = mBuilder.getUintConstant(0);
1670 break;
1671 case EbtBool:
1672 zeroId = mBuilder.getBoolConstant(0);
1673 break;
1674 default:
1675 UNREACHABLE();
1676 }
1677
1678 spirv::IdRefList componentIds(type.getRows(), zeroId);
1679 spirv::IdRefList columnIds;
1680
1681 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1682
1683 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1684 {
1685 columnIds.push_back(mBuilder.getNewId(decorations));
1686
1687 // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1688 if (columnIndex < type.getRows())
1689 {
1690 componentIds[columnIndex] = scalarId;
1691 }
1692 if (columnIndex > 0 && columnIndex <= type.getRows())
1693 {
1694 componentIds[columnIndex - 1] = zeroId;
1695 }
1696
1697 // Create the column.
1698 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1699 columnIds.back(), componentIds);
1700 }
1701
1702 // Create the matrix out of the columns.
1703 const spirv::IdRef result = mBuilder.getNewId(decorations);
1704 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1705 columnIds);
1706 return result;
1707 }
1708
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1709 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1710 TIntermAggregate *node,
1711 spirv::IdRef typeId,
1712 const spirv::IdRefList ¶meters)
1713 {
1714 // matNxM(v1.zy, v2.x, ...) translates to:
1715 //
1716 // %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1717 // ...
1718 // %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1719
1720 const TType &type = node->getType();
1721
1722 SpirvDecorations decorations = mBuilder.getDecorations(type);
1723
1724 spirv::IdRefList extractedComponents;
1725 extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1726
1727 spirv::IdRefList columnIds;
1728
1729 const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1730
1731 // Chunk up the extracted components by column and construct intermediary vectors.
1732 for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1733 {
1734 columnIds.push_back(mBuilder.getNewId(decorations));
1735
1736 auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1737 const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1738
1739 // Create the column.
1740 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1741 columnIds.back(), componentIds);
1742 }
1743
1744 const spirv::IdRef result = mBuilder.getNewId(decorations);
1745 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1746 columnIds);
1747 return result;
1748 }
1749
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
    TIntermAggregate *node,
    spirv::IdRef typeId,
    const spirv::IdRefList &parameters)
{
    // matNxM(m) translates to:
    //
    // - If m is SxR where S>=N and R>=M:
    //
    //     %c0 = OpCompositeExtract %vecR %m 0
    //     %c1 = OpCompositeExtract %vecR %m 1
    //     ...
    //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
    //     ...
    //     %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
    //
    // - Otherwise, an identity matrix is created and superimposed by m:
    //
    //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
    //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
    //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
    //     %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
    //     %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3

    const TType &type          = node->getType();
    const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();

    SpirvDecorations decorations = mBuilder.getDecorations(type);

    // A matrix-from-matrix constructor has exactly one argument.
    ASSERT(parameters.size() == 1);

    spirv::IdRefList columnIds;

    const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());

    if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
    {
        // If the parameter is a larger matrix than the constructor type, extract the columns
        // directly and potentially swizzle them.
        TType paramColumnType(parameterType);
        paramColumnType.toMatrixColumnType();
        const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;

        const bool needsSwizzle           = parameterType.getRows() > type.getRows();
        // Identity swizzle (0, 1, 2, 3), truncated to the destination row count; selects the
        // first M components of each extracted column.
        spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
                                             spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
        swizzle.resize_down(type.getRows());

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            // Extract the column.
            const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
            spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
                                         parameterColumnId, parameters[0],
                                         {spirv::LiteralInteger(columnIndex)});

            // If the column has too many components, select the appropriate number of components.
            spirv::IdRef constructorColumnId = parameterColumnId;
            if (needsSwizzle)
            {
                constructorColumnId = mBuilder.getNewId(decorations);
                spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                          constructorColumnId, parameterColumnId, parameterColumnId,
                                          swizzle);
            }

            columnIds.push_back(constructorColumnId);
        }
    }
    else
    {
        // Otherwise create an identity matrix and fill in the components that can be taken from the
        // given parameter.
        TType paramComponentType(parameterType);
        paramComponentType.toComponentType();
        const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;

        for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
        {
            spirv::IdRefList componentIds;

            for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
            {
                // Take the component from the constructor parameter if possible.
                spirv::IdRef componentId;
                if (componentIndex < parameterType.getRows() &&
                    columnIndex < parameterType.getCols())
                {
                    componentId = mBuilder.getNewId(decorations);
                    spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                                 paramComponentTypeId, componentId, parameters[0],
                                                 {spirv::LiteralInteger(columnIndex),
                                                  spirv::LiteralInteger(componentIndex)});
                }
                else
                {
                    // Outside the parameter's bounds, fill with the identity matrix pattern:
                    // one on the diagonal, zero elsewhere.
                    const bool isOnDiagonal = columnIndex == componentIndex;
                    switch (type.getBasicType())
                    {
                        case EbtFloat:
                            componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
                            break;
                        case EbtInt:
                            componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtUInt:
                            componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
                            break;
                        case EbtBool:
                            componentId = mBuilder.getBoolConstant(isOnDiagonal);
                            break;
                        default:
                            UNREACHABLE();
                    }
                }

                componentIds.push_back(componentId);
            }

            // Create the column vector.
            columnIds.push_back(mBuilder.getNewId(decorations));
            spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
                                           columnIds.back(), componentIds);
        }
    }

    // Create the matrix out of the columns.
    const spirv::IdRef result = mBuilder.getNewId(decorations);
    spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                                   columnIds);
    return result;
}
1881
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1882 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1883 size_t skipCount,
1884 spirv::IdRefList *paramTypeIds)
1885 {
1886 const size_t parameterCount = node->getChildCount();
1887 spirv::IdRefList parameters;
1888
1889 for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1890 {
1891 // Take each parameter that is visited and evaluate it as rvalue
1892 NodeData ¶m = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1893
1894 spirv::IdRef paramTypeId;
1895 const spirv::IdRef paramValue = accessChainLoad(
1896 ¶m, node->getChildNode(paramIndex)->getAsTyped()->getType(), ¶mTypeId);
1897
1898 parameters.push_back(paramValue);
1899 if (paramTypeIds)
1900 {
1901 paramTypeIds->push_back(paramTypeId);
1902 }
1903 }
1904
1905 return parameters;
1906 }
1907
void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
                                             size_t componentCount,
                                             const spirv::IdRefList &parameters,
                                             spirv::IdRefList *extractedComponentsOut)
{
    // A helper function that takes the list of parameters passed to a constructor (which may have
    // more components than necessary) and extracts the first componentCount components.
    // Components are produced in argument order, and in column-major order within matrix
    // arguments.  Scalar and vector components are cast to the constructed type's basic type;
    // matrix components are not (the caller casts the final composite if needed).
    const TIntermSequence &arguments = *node->getSequence();

    const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
    const TType &expectedType          = node->getType();

    // There must be exactly one loaded id per constructor argument.
    ASSERT(arguments.size() == parameters.size());

    for (size_t argumentIndex = 0;
         argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
         ++argumentIndex)
    {
        TIntermNode *argument     = arguments[argumentIndex];
        const TType &argumentType = argument->getAsTyped()->getType();
        const spirv::IdRef parameterId = parameters[argumentIndex];

        if (argumentType.isScalar())
        {
            // For scalar parameters, there's nothing to do other than a potential cast.
            // Constant unions are already folded to the expected type, so no cast is needed.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);
            extractedComponentsOut->push_back(castParameterId);
            continue;
        }
        if (argumentType.isVector())
        {
            TType componentType(argumentType);
            componentType.toComponentType();
            componentType.setBasicType(expectedType.getBasicType());
            const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

            // Cast the whole vector parameter in one go.  As with scalars, constant unions are
            // already of the expected type.
            const spirv::IdRef castParameterId =
                argument->getAsConstantUnion()
                    ? parameterId
                    : castBasicType(parameterId, argumentType, expectedType, nullptr);

            // For vector parameters, take components out of the vector one by one.
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
                                             componentTypeId, componentId, castParameterId,
                                             {spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
            continue;
        }

        ASSERT(argumentType.isMatrix());

        TType componentType(argumentType);
        componentType.toComponentType();
        const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;

        // For matrix parameters, take components out of the matrix one by one in column-major
        // order.  No cast is done here; it would only be required for vector constructors with
        // matrix parameters, in which case the resulting vector is cast in the end.
        for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
                                      extractedComponentsOut->size() < componentCount;
             ++columnIndex)
        {
            for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
                                             extractedComponentsOut->size() < componentCount;
                 ++componentIndex)
            {
                const spirv::IdRef componentId = mBuilder.getNewId(decorations);
                spirv::WriteCompositeExtract(
                    mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
                    parameterId,
                    {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});

                extractedComponentsOut->push_back(componentId);
            }
        }
    }
}
1996
void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
{
    // Emulate && and || as such:
    //
    //   || => if (!left) result = right
    //   && => if ( left) result = right
    //
    // When this function is called, |left| has already been visited, so it creates the appropriate
    // |if| construct in preparation for visiting |right|.

    // Load |left| and replace the access chain with an rvalue that's the result.
    spirv::IdRef typeId;
    const spirv::IdRef left =
        accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
    nodeDataInitRValue(&mNodeData.back(), left, typeId);

    // Keep the id of the block |left| was evaluated in.  endShortCircuit retrieves it as the
    // incoming block for the OpPhi that merges the two possible results.
    mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());

    // Two blocks necessary, one for the |if| block, and one for the merge block.
    mBuilder.startConditional(2, false, false);

    // Generate the branch instructions.
    const SpirvConditional *conditional = mBuilder.getCurrentConditional();

    const spirv::IdRef mergeBlock = conditional->blockIds.back();
    const spirv::IdRef ifBlock    = conditional->blockIds.front();
    // For &&, |right| is evaluated only when |left| is true; for ||, only when |left| is false.
    const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
    const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;

    // Note that no logical not is necessary.  For ||, the branch will target the merge block in
    // the true case.
    mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
}
2031
spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
{
    // Counterpart of startShortCircuit: called after |right| has been visited.  Produces an OpPhi
    // selecting between the cached |left| result and the just-evaluated |right| result.

    // Load the right hand side.
    const spirv::IdRef right =
        accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
    mNodeData.pop_back();

    // Get the id of the block |right| is evaluated in.
    const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();

    // And the cached id of the block |left| is evaluated in (stored by startShortCircuit).
    ASSERT(mNodeData.back().idList.size() == 1);
    const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
    mNodeData.back().idList.clear();

    // Move on to the merge block.
    mBuilder.writeBranchConditionalBlockEnd();

    // Pop from the conditional stack.
    mBuilder.endConditional();

    // Get the previously loaded result of the left hand side.
    *typeId = mNodeData.back().accessChain.baseTypeId;
    const spirv::IdRef left = mNodeData.back().baseId;

    // Create an OpPhi instruction that selects either the |left| or |right| based on which block
    // was traversed.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

    spirv::WritePhi(
        mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
        {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});

    return result;
}
2067
spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
                                                      spirv::IdRef resultTypeId)
{
    const TFunction *function = node->getFunction();
    ASSERT(function);

    // The callee must have been declared earlier in the traversal.
    ASSERT(mFunctionIdMap.count(function) > 0);
    const spirv::IdRef functionId = mFunctionIdMap[function].functionId;

    // Get the list of parameters passed to the function.  The function parameters can only be
    // memory variables, or if the function argument is |const|, an rvalue.
    //
    // For opaque uniforms, pass it directly as lvalue.
    //
    // For in variables:
    //
    // - If the parameter is const, pass it directly as rvalue, otherwise
    // - Write it to a temp variable first and pass that.
    //
    // For out variables:
    //
    // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
    //
    // For inout variables:
    //
    // - Write the parameter to a temp variable and pass that.  After the function call, copy that
    //   variable back to the parameter.
    //
    // Note that in GLSL, in parameters are considered "copied" to the function.  In SPIR-V, every
    // parameter is implicitly inout.  If a function takes an in parameter and modifies it, the
    // caller has to ensure that it calls the function with a copy.  Currently, the functions don't
    // track whether an in parameter is modified, so we conservatively assume it is.  Even for out
    // and inout parameters, GLSL expects each function to operate on their local copy until the end
    // of the function; this has observable side effects if the out variable aliases another
    // variable the function has access to (another out variable, a global variable etc).
    //
    const size_t parameterCount = node->getChildCount();
    spirv::IdRefList parameters;
    // For each parameter, the temp variable id and its type id; left invalid for parameters that
    // don't need a temp (const and opaque).
    spirv::IdRefList tempVarIds(parameterCount);
    spirv::IdRefList tempVarTypeIds(parameterCount);

    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        spirv::IdRef paramValue;

        if (paramQualifier == EvqParamConst)
        {
            // |const| parameters are passed as rvalue.
            paramValue = accessChainLoad(&param, argType, nullptr);
        }
        else if (IsOpaqueType(paramType.getBasicType()))
        {
            // Opaque uniforms are passed by pointer.
            paramValue = accessChainCollapse(&param);
        }
        else
        {
            ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
                   paramQualifier == EvqParamInOut);

            // Need to create a temp variable and pass that.
            tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
            tempVarIds[paramIndex]     = mBuilder.declareVariable(
                tempVarTypeIds[paramIndex], spv::StorageClassFunction,
                mBuilder.getDecorations(argType), nullptr, "param", nullptr);

            // If it's an in or inout parameter, the temp variable needs to be initialized with the
            // value of the parameter first.
            if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
            {
                paramValue = accessChainLoad(&param, argType, nullptr);
                spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
                                  paramValue, nullptr);
            }

            paramValue = tempVarIds[paramIndex];
        }

        parameters.push_back(paramValue);
    }

    // Make the actual function call.
    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
    spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
                             functionId, parameters);

    // Copy from the out and inout temp variables back to the original parameters.
    for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
    {
        // Parameters that were passed directly (const, opaque) have no temp variable.
        if (!tempVarIds[paramIndex].valid())
        {
            continue;
        }

        const TType &paramType           = function->getParam(paramIndex)->getType();
        const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
        const TQualifier &paramQualifier = paramType.getQualifier();
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];

        // |in| temp variables were only inputs; nothing to copy back.
        if (paramQualifier == EvqParamIn)
        {
            continue;
        }

        // Copy from the temp variable to the parameter.
        NodeData tempVarData;
        nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
                           spv::StorageClassFunction, {});
        const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
        accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
    }

    return result;
}
2187
void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
{
    // .length() on sized arrays is already constant folded, so this operation only applies to
    // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
    // index of last_member, so those need to be extracted from the access chain.  Additionally,
    // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
    // necessary.

    // Inspect the children.  There are two possibilities:
    //
    // - last_member.length(): In this case, the id of the nameless ssbo is used.
    // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
    // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
    //
    // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
    // which is not useful.

    spirv::IdRef accessChainId;
    spirv::LiteralInteger fieldIndex;

    if (node->getOperand()->getAsSymbolNode())
    {
        // If the operand is a symbol referencing the last member of a nameless interface block,
        // visit the symbol to get the id of the interface block.
        node->getOperand()->getAsSymbolNode()->traverse(this);

        // The access chain must only include the base id + one literal field index.
        ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());

        accessChainId = mNodeData.back().baseId;
        fieldIndex    = mNodeData.back().idList.back().literal;
    }
    else
    {
        // Otherwise make sure not to traverse the field index selection node so that the access
        // chain would not include it.
        TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
        ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);

        TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
        TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
        ASSERT(indexNode);

        // Visit the expression.
        interfaceBlockExpression->traverse(this);

        accessChainId = accessChainCollapse(&mNodeData.back());
        fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
    }

    // Get the int and uint type ids.
    const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
    const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);

    // Generate the instruction.  OpArrayLength yields an unsigned int.
    const spirv::IdRef resultId = mBuilder.getNewId({});
    spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
                            accessChainId, fieldIndex);

    // Cast to int, since GLSL's .length() returns int.
    const spirv::IdRef castResultId = mBuilder.getNewId({});
    spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);

    // Replace the access chain with an rvalue that's the result.
    nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
}
2254
IsShortCircuitNeeded(TIntermOperator * node)2255 bool IsShortCircuitNeeded(TIntermOperator *node)
2256 {
2257 TOperator op = node->getOp();
2258
2259 // Short circuit is only necessary for && and ||.
2260 if (op != EOpLogicalAnd && op != EOpLogicalOr)
2261 {
2262 return false;
2263 }
2264
2265 ASSERT(node->getChildCount() == 2);
2266
2267 // If the right hand side does not have side effects, short-circuiting is unnecessary.
2268 // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2269 // complexity of the right hand side expression. We could potentially only allow
2270 // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2271 // expressions be placed inside an if block. http://anglebug.com/40096715
2272 return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2273 }
2274
// Function-pointer aliases matching the signatures of the autogenerated spirv::Write*
// instruction builders.  visitOperator() selects one of these based on the AST operator and
// invokes it to emit the corresponding SPIR-V instruction.

// Instruction with one operand, e.g. OpFNegate.
using WriteUnaryOp      = void (*)(spirv::Blob *blob,
                              spirv::IdResultType idResultType,
                              spirv::IdResult idResult,
                              spirv::IdRef operand);
// Instruction with two operands, e.g. OpFAdd.
using WriteBinaryOp     = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef operand1,
                               spirv::IdRef operand2);
// Instruction with three operands, e.g. OpSelect.
using WriteTernaryOp    = void (*)(spirv::Blob *blob,
                                spirv::IdResultType idResultType,
                                spirv::IdResult idResult,
                                spirv::IdRef operand1,
                                spirv::IdRef operand2,
                                spirv::IdRef operand3);
// Instruction with four operands.
using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
                                   spirv::IdResultType idResultType,
                                   spirv::IdResult idResult,
                                   spirv::IdRef operand1,
                                   spirv::IdRef operand2,
                                   spirv::IdRef operand3,
                                   spirv::IdRef operand4);
// Atomic instruction, e.g. OpAtomicIAdd: takes a pointer, a scope, memory semantics and a value.
using WriteAtomicOp     = void (*)(spirv::Blob *blob,
                               spirv::IdResultType idResultType,
                               spirv::IdResult idResult,
                               spirv::IdRef pointer,
                               spirv::IdScope scope,
                               spirv::IdMemorySemantics semantics,
                               spirv::IdRef value);
2304
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2305 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2306 {
2307 // Handle special groups.
2308 const TOperator op = node->getOp();
2309 if (op == EOpEqual || op == EOpNotEqual)
2310 {
2311 return createCompare(node, resultTypeId);
2312 }
2313 if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2314 {
2315 return createAtomicBuiltIn(node, resultTypeId);
2316 }
2317 if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2318 {
2319 return createImageTextureBuiltIn(node, resultTypeId);
2320 }
2321 if (op == EOpSubpassLoad)
2322 {
2323 return createSubpassLoadBuiltIn(node, resultTypeId);
2324 }
2325 if (BuiltInGroup::IsInterpolationFS(op))
2326 {
2327 return createInterpolate(node, resultTypeId);
2328 }
2329
2330 const size_t childCount = node->getChildCount();
2331 TIntermTyped *firstChild = node->getChildNode(0)->getAsTyped();
2332 TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2333
2334 const TType &firstOperandType = firstChild->getType();
2335 const TBasicType basicType = firstOperandType.getBasicType();
2336 const bool isFloat = basicType == EbtFloat;
2337 const bool isUnsigned = basicType == EbtUInt;
2338 const bool isBool = basicType == EbtBool;
2339 // Whether this is a pre/post increment/decrement operator.
2340 bool isIncrementOrDecrement = false;
2341 // Whether the operation needs to be applied column by column.
2342 bool operateOnColumns =
2343 childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2344 // Whether the operands need to be swapped in the (binary) instruction
2345 bool binarySwapOperands = false;
2346 // Whether the instruction is matrix/scalar. This is implemented with matrix*(1/scalar).
2347 bool binaryInvertSecondParameter = false;
2348 // Whether the scalar operand needs to be extended to match the other operand which is a vector
2349 // (in a binary or extended operation).
2350 bool extendScalarToVector = true;
2351 // Some built-ins have out parameters at the end of the list of parameters.
2352 size_t lvalueCount = 0;
2353
2354 WriteUnaryOp writeUnaryOp = nullptr;
2355 WriteBinaryOp writeBinaryOp = nullptr;
2356 WriteTernaryOp writeTernaryOp = nullptr;
2357 WriteQuaternaryOp writeQuaternaryOp = nullptr;
2358
2359 // Some operators are implemented with an extended instruction.
2360 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2361
2362 switch (op)
2363 {
2364 case EOpNegative:
2365 operateOnColumns = firstOperandType.isMatrix();
2366 if (isFloat)
2367 writeUnaryOp = spirv::WriteFNegate;
2368 else
2369 writeUnaryOp = spirv::WriteSNegate;
2370 break;
2371 case EOpPositive:
2372 // This is a noop.
2373 return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2374
2375 case EOpLogicalNot:
2376 case EOpNotComponentWise:
2377 writeUnaryOp = spirv::WriteLogicalNot;
2378 break;
2379 case EOpBitwiseNot:
2380 writeUnaryOp = spirv::WriteNot;
2381 break;
2382
2383 case EOpPostIncrement:
2384 case EOpPreIncrement:
2385 isIncrementOrDecrement = true;
2386 operateOnColumns = firstOperandType.isMatrix();
2387 if (isFloat)
2388 writeBinaryOp = spirv::WriteFAdd;
2389 else
2390 writeBinaryOp = spirv::WriteIAdd;
2391 break;
2392 case EOpPostDecrement:
2393 case EOpPreDecrement:
2394 isIncrementOrDecrement = true;
2395 operateOnColumns = firstOperandType.isMatrix();
2396 if (isFloat)
2397 writeBinaryOp = spirv::WriteFSub;
2398 else
2399 writeBinaryOp = spirv::WriteISub;
2400 break;
2401
2402 case EOpAdd:
2403 case EOpAddAssign:
2404 if (isFloat)
2405 writeBinaryOp = spirv::WriteFAdd;
2406 else
2407 writeBinaryOp = spirv::WriteIAdd;
2408 break;
2409 case EOpSub:
2410 case EOpSubAssign:
2411 if (isFloat)
2412 writeBinaryOp = spirv::WriteFSub;
2413 else
2414 writeBinaryOp = spirv::WriteISub;
2415 break;
2416 case EOpMul:
2417 case EOpMulAssign:
2418 case EOpMatrixCompMult:
2419 if (isFloat)
2420 writeBinaryOp = spirv::WriteFMul;
2421 else
2422 writeBinaryOp = spirv::WriteIMul;
2423 break;
2424 case EOpDiv:
2425 case EOpDivAssign:
2426 if (isFloat)
2427 {
2428 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2429 {
2430 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2431 binaryInvertSecondParameter = true;
2432 operateOnColumns = false;
2433 extendScalarToVector = false;
2434 }
2435 else
2436 {
2437 writeBinaryOp = spirv::WriteFDiv;
2438 }
2439 }
2440 else if (isUnsigned)
2441 writeBinaryOp = spirv::WriteUDiv;
2442 else
2443 writeBinaryOp = spirv::WriteSDiv;
2444 break;
2445 case EOpIMod:
2446 case EOpIModAssign:
2447 if (isFloat)
2448 writeBinaryOp = spirv::WriteFMod;
2449 else if (isUnsigned)
2450 writeBinaryOp = spirv::WriteUMod;
2451 else
2452 writeBinaryOp = spirv::WriteSMod;
2453 break;
2454
2455 case EOpEqualComponentWise:
2456 if (isFloat)
2457 writeBinaryOp = spirv::WriteFOrdEqual;
2458 else if (isBool)
2459 writeBinaryOp = spirv::WriteLogicalEqual;
2460 else
2461 writeBinaryOp = spirv::WriteIEqual;
2462 break;
2463 case EOpNotEqualComponentWise:
2464 if (isFloat)
2465 writeBinaryOp = spirv::WriteFUnordNotEqual;
2466 else if (isBool)
2467 writeBinaryOp = spirv::WriteLogicalNotEqual;
2468 else
2469 writeBinaryOp = spirv::WriteINotEqual;
2470 break;
2471 case EOpLessThan:
2472 case EOpLessThanComponentWise:
2473 if (isFloat)
2474 writeBinaryOp = spirv::WriteFOrdLessThan;
2475 else if (isUnsigned)
2476 writeBinaryOp = spirv::WriteULessThan;
2477 else
2478 writeBinaryOp = spirv::WriteSLessThan;
2479 break;
2480 case EOpGreaterThan:
2481 case EOpGreaterThanComponentWise:
2482 if (isFloat)
2483 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2484 else if (isUnsigned)
2485 writeBinaryOp = spirv::WriteUGreaterThan;
2486 else
2487 writeBinaryOp = spirv::WriteSGreaterThan;
2488 break;
2489 case EOpLessThanEqual:
2490 case EOpLessThanEqualComponentWise:
2491 if (isFloat)
2492 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2493 else if (isUnsigned)
2494 writeBinaryOp = spirv::WriteULessThanEqual;
2495 else
2496 writeBinaryOp = spirv::WriteSLessThanEqual;
2497 break;
2498 case EOpGreaterThanEqual:
2499 case EOpGreaterThanEqualComponentWise:
2500 if (isFloat)
2501 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2502 else if (isUnsigned)
2503 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2504 else
2505 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2506 break;
2507
2508 case EOpVectorTimesScalar:
2509 case EOpVectorTimesScalarAssign:
2510 if (isFloat)
2511 {
2512 writeBinaryOp = spirv::WriteVectorTimesScalar;
2513 binarySwapOperands = node->getChildNode(1)->getAsTyped()->getType().isVector();
2514 extendScalarToVector = false;
2515 }
2516 else
2517 writeBinaryOp = spirv::WriteIMul;
2518 break;
2519 case EOpVectorTimesMatrix:
2520 case EOpVectorTimesMatrixAssign:
2521 writeBinaryOp = spirv::WriteVectorTimesMatrix;
2522 operateOnColumns = false;
2523 break;
2524 case EOpMatrixTimesVector:
2525 writeBinaryOp = spirv::WriteMatrixTimesVector;
2526 operateOnColumns = false;
2527 break;
2528 case EOpMatrixTimesScalar:
2529 case EOpMatrixTimesScalarAssign:
2530 writeBinaryOp = spirv::WriteMatrixTimesScalar;
2531 binarySwapOperands = secondChild->getType().isMatrix();
2532 operateOnColumns = false;
2533 extendScalarToVector = false;
2534 break;
2535 case EOpMatrixTimesMatrix:
2536 case EOpMatrixTimesMatrixAssign:
2537 writeBinaryOp = spirv::WriteMatrixTimesMatrix;
2538 operateOnColumns = false;
2539 break;
2540
2541 case EOpLogicalOr:
2542 ASSERT(!IsShortCircuitNeeded(node));
2543 extendScalarToVector = false;
2544 writeBinaryOp = spirv::WriteLogicalOr;
2545 break;
2546 case EOpLogicalXor:
2547 extendScalarToVector = false;
2548 writeBinaryOp = spirv::WriteLogicalNotEqual;
2549 break;
2550 case EOpLogicalAnd:
2551 ASSERT(!IsShortCircuitNeeded(node));
2552 extendScalarToVector = false;
2553 writeBinaryOp = spirv::WriteLogicalAnd;
2554 break;
2555
2556 case EOpBitShiftLeft:
2557 case EOpBitShiftLeftAssign:
2558 writeBinaryOp = spirv::WriteShiftLeftLogical;
2559 break;
2560 case EOpBitShiftRight:
2561 case EOpBitShiftRightAssign:
2562 if (isUnsigned)
2563 writeBinaryOp = spirv::WriteShiftRightLogical;
2564 else
2565 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2566 break;
2567 case EOpBitwiseAnd:
2568 case EOpBitwiseAndAssign:
2569 writeBinaryOp = spirv::WriteBitwiseAnd;
2570 break;
2571 case EOpBitwiseXor:
2572 case EOpBitwiseXorAssign:
2573 writeBinaryOp = spirv::WriteBitwiseXor;
2574 break;
2575 case EOpBitwiseOr:
2576 case EOpBitwiseOrAssign:
2577 writeBinaryOp = spirv::WriteBitwiseOr;
2578 break;
2579
2580 case EOpRadians:
2581 extendedInst = spv::GLSLstd450Radians;
2582 break;
2583 case EOpDegrees:
2584 extendedInst = spv::GLSLstd450Degrees;
2585 break;
2586 case EOpSin:
2587 extendedInst = spv::GLSLstd450Sin;
2588 break;
2589 case EOpCos:
2590 extendedInst = spv::GLSLstd450Cos;
2591 break;
2592 case EOpTan:
2593 extendedInst = spv::GLSLstd450Tan;
2594 break;
2595 case EOpAsin:
2596 extendedInst = spv::GLSLstd450Asin;
2597 break;
2598 case EOpAcos:
2599 extendedInst = spv::GLSLstd450Acos;
2600 break;
2601 case EOpAtan:
2602 extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2603 break;
2604 case EOpSinh:
2605 extendedInst = spv::GLSLstd450Sinh;
2606 break;
2607 case EOpCosh:
2608 extendedInst = spv::GLSLstd450Cosh;
2609 break;
2610 case EOpTanh:
2611 extendedInst = spv::GLSLstd450Tanh;
2612 break;
2613 case EOpAsinh:
2614 extendedInst = spv::GLSLstd450Asinh;
2615 break;
2616 case EOpAcosh:
2617 extendedInst = spv::GLSLstd450Acosh;
2618 break;
2619 case EOpAtanh:
2620 extendedInst = spv::GLSLstd450Atanh;
2621 break;
2622
2623 case EOpPow:
2624 extendedInst = spv::GLSLstd450Pow;
2625 break;
2626 case EOpExp:
2627 extendedInst = spv::GLSLstd450Exp;
2628 break;
2629 case EOpLog:
2630 extendedInst = spv::GLSLstd450Log;
2631 break;
2632 case EOpExp2:
2633 extendedInst = spv::GLSLstd450Exp2;
2634 break;
2635 case EOpLog2:
2636 extendedInst = spv::GLSLstd450Log2;
2637 break;
2638 case EOpSqrt:
2639 extendedInst = spv::GLSLstd450Sqrt;
2640 break;
2641 case EOpInversesqrt:
2642 extendedInst = spv::GLSLstd450InverseSqrt;
2643 break;
2644
2645 case EOpAbs:
2646 if (isFloat)
2647 extendedInst = spv::GLSLstd450FAbs;
2648 else
2649 extendedInst = spv::GLSLstd450SAbs;
2650 break;
2651 case EOpSign:
2652 if (isFloat)
2653 extendedInst = spv::GLSLstd450FSign;
2654 else
2655 extendedInst = spv::GLSLstd450SSign;
2656 break;
2657 case EOpFloor:
2658 extendedInst = spv::GLSLstd450Floor;
2659 break;
2660 case EOpTrunc:
2661 extendedInst = spv::GLSLstd450Trunc;
2662 break;
2663 case EOpRound:
2664 extendedInst = spv::GLSLstd450Round;
2665 break;
2666 case EOpRoundEven:
2667 extendedInst = spv::GLSLstd450RoundEven;
2668 break;
2669 case EOpCeil:
2670 extendedInst = spv::GLSLstd450Ceil;
2671 break;
2672 case EOpFract:
2673 extendedInst = spv::GLSLstd450Fract;
2674 break;
2675 case EOpMod:
2676 if (isFloat)
2677 writeBinaryOp = spirv::WriteFMod;
2678 else if (isUnsigned)
2679 writeBinaryOp = spirv::WriteUMod;
2680 else
2681 writeBinaryOp = spirv::WriteSMod;
2682 break;
2683 case EOpMin:
2684 if (isFloat)
2685 extendedInst = spv::GLSLstd450FMin;
2686 else if (isUnsigned)
2687 extendedInst = spv::GLSLstd450UMin;
2688 else
2689 extendedInst = spv::GLSLstd450SMin;
2690 break;
2691 case EOpMax:
2692 if (isFloat)
2693 extendedInst = spv::GLSLstd450FMax;
2694 else if (isUnsigned)
2695 extendedInst = spv::GLSLstd450UMax;
2696 else
2697 extendedInst = spv::GLSLstd450SMax;
2698 break;
2699 case EOpClamp:
2700 if (isFloat)
2701 extendedInst = spv::GLSLstd450FClamp;
2702 else if (isUnsigned)
2703 extendedInst = spv::GLSLstd450UClamp;
2704 else
2705 extendedInst = spv::GLSLstd450SClamp;
2706 break;
2707 case EOpMix:
2708 if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2709 EbtBool)
2710 {
2711 writeTernaryOp = spirv::WriteSelect;
2712 }
2713 else
2714 {
2715 ASSERT(isFloat);
2716 extendedInst = spv::GLSLstd450FMix;
2717 }
2718 break;
2719 case EOpStep:
2720 extendedInst = spv::GLSLstd450Step;
2721 break;
2722 case EOpSmoothstep:
2723 extendedInst = spv::GLSLstd450SmoothStep;
2724 break;
2725 case EOpModf:
2726 extendedInst = spv::GLSLstd450ModfStruct;
2727 lvalueCount = 1;
2728 break;
2729 case EOpIsnan:
2730 writeUnaryOp = spirv::WriteIsNan;
2731 break;
2732 case EOpIsinf:
2733 writeUnaryOp = spirv::WriteIsInf;
2734 break;
2735 case EOpFloatBitsToInt:
2736 case EOpFloatBitsToUint:
2737 case EOpIntBitsToFloat:
2738 case EOpUintBitsToFloat:
2739 writeUnaryOp = spirv::WriteBitcast;
2740 break;
2741 case EOpFma:
2742 extendedInst = spv::GLSLstd450Fma;
2743 break;
2744 case EOpFrexp:
2745 extendedInst = spv::GLSLstd450FrexpStruct;
2746 lvalueCount = 1;
2747 break;
2748 case EOpLdexp:
2749 extendedInst = spv::GLSLstd450Ldexp;
2750 break;
2751 case EOpPackSnorm2x16:
2752 extendedInst = spv::GLSLstd450PackSnorm2x16;
2753 break;
2754 case EOpPackUnorm2x16:
2755 extendedInst = spv::GLSLstd450PackUnorm2x16;
2756 break;
2757 case EOpPackHalf2x16:
2758 extendedInst = spv::GLSLstd450PackHalf2x16;
2759 break;
2760 case EOpUnpackSnorm2x16:
2761 extendedInst = spv::GLSLstd450UnpackSnorm2x16;
2762 extendScalarToVector = false;
2763 break;
2764 case EOpUnpackUnorm2x16:
2765 extendedInst = spv::GLSLstd450UnpackUnorm2x16;
2766 extendScalarToVector = false;
2767 break;
2768 case EOpUnpackHalf2x16:
2769 extendedInst = spv::GLSLstd450UnpackHalf2x16;
2770 extendScalarToVector = false;
2771 break;
2772 case EOpPackUnorm4x8:
2773 extendedInst = spv::GLSLstd450PackUnorm4x8;
2774 break;
2775 case EOpPackSnorm4x8:
2776 extendedInst = spv::GLSLstd450PackSnorm4x8;
2777 break;
2778 case EOpUnpackUnorm4x8:
2779 extendedInst = spv::GLSLstd450UnpackUnorm4x8;
2780 extendScalarToVector = false;
2781 break;
2782 case EOpUnpackSnorm4x8:
2783 extendedInst = spv::GLSLstd450UnpackSnorm4x8;
2784 extendScalarToVector = false;
2785 break;
2786
2787 case EOpLength:
2788 extendedInst = spv::GLSLstd450Length;
2789 break;
2790 case EOpDistance:
2791 extendedInst = spv::GLSLstd450Distance;
2792 break;
2793 case EOpDot:
2794 // Use normal multiplication for scalars.
2795 if (firstOperandType.isScalar())
2796 {
2797 if (isFloat)
2798 writeBinaryOp = spirv::WriteFMul;
2799 else
2800 writeBinaryOp = spirv::WriteIMul;
2801 }
2802 else
2803 {
2804 writeBinaryOp = spirv::WriteDot;
2805 }
2806 break;
2807 case EOpCross:
2808 extendedInst = spv::GLSLstd450Cross;
2809 break;
2810 case EOpNormalize:
2811 extendedInst = spv::GLSLstd450Normalize;
2812 break;
2813 case EOpFaceforward:
2814 extendedInst = spv::GLSLstd450FaceForward;
2815 break;
2816 case EOpReflect:
2817 extendedInst = spv::GLSLstd450Reflect;
2818 break;
2819 case EOpRefract:
2820 extendedInst = spv::GLSLstd450Refract;
2821 extendScalarToVector = false;
2822 break;
2823
2824 case EOpOuterProduct:
2825 writeBinaryOp = spirv::WriteOuterProduct;
2826 break;
2827 case EOpTranspose:
2828 writeUnaryOp = spirv::WriteTranspose;
2829 break;
2830 case EOpDeterminant:
2831 extendedInst = spv::GLSLstd450Determinant;
2832 break;
2833 case EOpInverse:
2834 extendedInst = spv::GLSLstd450MatrixInverse;
2835 break;
2836
2837 case EOpAny:
2838 writeUnaryOp = spirv::WriteAny;
2839 break;
2840 case EOpAll:
2841 writeUnaryOp = spirv::WriteAll;
2842 break;
2843
2844 case EOpBitfieldExtract:
2845 if (isUnsigned)
2846 writeTernaryOp = spirv::WriteBitFieldUExtract;
2847 else
2848 writeTernaryOp = spirv::WriteBitFieldSExtract;
2849 break;
2850 case EOpBitfieldInsert:
2851 writeQuaternaryOp = spirv::WriteBitFieldInsert;
2852 break;
2853 case EOpBitfieldReverse:
2854 writeUnaryOp = spirv::WriteBitReverse;
2855 break;
2856 case EOpBitCount:
2857 writeUnaryOp = spirv::WriteBitCount;
2858 break;
2859 case EOpFindLSB:
2860 extendedInst = spv::GLSLstd450FindILsb;
2861 break;
2862 case EOpFindMSB:
2863 if (isUnsigned)
2864 extendedInst = spv::GLSLstd450FindUMsb;
2865 else
2866 extendedInst = spv::GLSLstd450FindSMsb;
2867 break;
2868 case EOpUaddCarry:
2869 writeBinaryOp = spirv::WriteIAddCarry;
2870 lvalueCount = 1;
2871 break;
2872 case EOpUsubBorrow:
2873 writeBinaryOp = spirv::WriteISubBorrow;
2874 lvalueCount = 1;
2875 break;
2876 case EOpUmulExtended:
2877 writeBinaryOp = spirv::WriteUMulExtended;
2878 lvalueCount = 2;
2879 break;
2880 case EOpImulExtended:
2881 writeBinaryOp = spirv::WriteSMulExtended;
2882 lvalueCount = 2;
2883 break;
2884
2885 case EOpRgb_2_yuv:
2886 case EOpYuv_2_rgb:
2887 // These built-ins are emulated, and shouldn't be encountered at this point.
2888 UNREACHABLE();
2889 break;
2890
2891 case EOpDFdx:
2892 writeUnaryOp = spirv::WriteDPdx;
2893 break;
2894 case EOpDFdy:
2895 writeUnaryOp = spirv::WriteDPdy;
2896 break;
2897 case EOpFwidth:
2898 writeUnaryOp = spirv::WriteFwidth;
2899 break;
2900
2901 default:
2902 UNREACHABLE();
2903 }
2904
2905 // Load the parameters.
2906 spirv::IdRefList parameterTypeIds;
2907 spirv::IdRefList parameters = loadAllParams(node, lvalueCount, ¶meterTypeIds);
2908
2909 if (isIncrementOrDecrement)
2910 {
2911 // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2912 // size vecN(1).
2913 const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2914 : firstOperandType.getNominalSize();
2915 const spirv::IdRef one =
2916 isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2917 parameters.push_back(one);
2918 }
2919
2920 const SpirvDecorations decorations =
2921 mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2922 spirv::IdRef result;
2923 if (node->getType().getBasicType() != EbtVoid)
2924 {
2925 result = mBuilder.getNewId(decorations);
2926 }
2927
2928 // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2929 // result is expected to be a struct instead.
2930 spirv::IdRef builtInResultTypeId = resultTypeId;
2931 spirv::IdRef builtInResult;
2932 if (lvalueCount > 0)
2933 {
2934 builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2935 builtInResult = mBuilder.getNewId({});
2936 }
2937 else
2938 {
2939 builtInResult = result;
2940 }
2941
2942 if (operateOnColumns)
2943 {
2944 // If negating a matrix, multiplying or comparing them, do that column by column.
2945 // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2946 // operating on the column.
2947 spirv::IdRefList columnIds;
2948
2949 const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2950
2951 const TType &matrixType =
2952 firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2953
2954 const spirv::IdRef columnTypeId =
2955 mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2956
2957 if (binarySwapOperands)
2958 {
2959 std::swap(parameters[0], parameters[1]);
2960 }
2961
2962 if (extendScalarToVector)
2963 {
2964 extendScalarParamsToVector(node, columnTypeId, ¶meters);
2965 }
2966
2967 // Extract and apply the operator to each column.
2968 for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
2969 {
2970 spirv::IdRef columnIdA = parameters[0];
2971 if (firstOperandType.isMatrix())
2972 {
2973 columnIdA = mBuilder.getNewId(operandDecorations);
2974 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2975 columnIdA, parameters[0],
2976 {spirv::LiteralInteger(columnIndex)});
2977 }
2978
2979 columnIds.push_back(mBuilder.getNewId(decorations));
2980
2981 if (writeUnaryOp)
2982 {
2983 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2984 columnIds.back(), columnIdA);
2985 }
2986 else
2987 {
2988 ASSERT(writeBinaryOp);
2989
2990 spirv::IdRef columnIdB = parameters[1];
2991 if (secondChild != nullptr && secondChild->getType().isMatrix())
2992 {
2993 columnIdB = mBuilder.getNewId(operandDecorations);
2994 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
2995 columnTypeId, columnIdB, parameters[1],
2996 {spirv::LiteralInteger(columnIndex)});
2997 }
2998
2999 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3000 columnIds.back(), columnIdA, columnIdB);
3001 }
3002 }
3003
3004 // Construct the result.
3005 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3006 builtInResult, columnIds);
3007 }
3008 else if (writeUnaryOp)
3009 {
3010 ASSERT(parameters.size() == 1);
3011 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3012 parameters[0]);
3013 }
3014 else if (writeBinaryOp)
3015 {
3016 ASSERT(parameters.size() == 2);
3017
3018 if (extendScalarToVector)
3019 {
3020 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3021 }
3022
3023 if (binarySwapOperands)
3024 {
3025 std::swap(parameters[0], parameters[1]);
3026 }
3027
3028 if (binaryInvertSecondParameter)
3029 {
3030 const spirv::IdRef one = mBuilder.getFloatConstant(1);
3031 const spirv::IdRef invertedParam = mBuilder.getNewId(
3032 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3033 spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3034 invertedParam, one, parameters[1]);
3035 parameters[1] = invertedParam;
3036 }
3037
3038 // Write the operation that combines the left and right values.
3039 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3040 parameters[0], parameters[1]);
3041 }
3042 else if (writeTernaryOp)
3043 {
3044 ASSERT(parameters.size() == 3);
3045
3046 // mix(a, b, bool) is the same as bool ? b : a;
3047 if (op == EOpMix)
3048 {
3049 std::swap(parameters[0], parameters[2]);
3050 }
3051
3052 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3053 parameters[0], parameters[1], parameters[2]);
3054 }
3055 else if (writeQuaternaryOp)
3056 {
3057 ASSERT(parameters.size() == 4);
3058
3059 writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3060 builtInResult, parameters[0], parameters[1], parameters[2],
3061 parameters[3]);
3062 }
3063 else
3064 {
3065 // It's an extended instruction.
3066 ASSERT(extendedInst != spv::GLSLstd450Bad);
3067
3068 if (extendScalarToVector)
3069 {
3070 extendScalarParamsToVector(node, builtInResultTypeId, ¶meters);
3071 }
3072
3073 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3074 builtInResult, mBuilder.getExtInstImportIdStd(),
3075 spirv::LiteralExtInstInteger(extendedInst), parameters);
3076 }
3077
3078 // If it's an assignment, store the calculated value.
3079 if (IsAssignment(node->getOp()))
3080 {
3081 accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3082 firstOperandType);
3083 }
3084
3085 // If the operation returns a struct, load the lsb and msb and store them in result/out
3086 // parameters.
3087 if (lvalueCount > 0)
3088 {
3089 storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3090 resultTypeId);
3091 }
3092
3093 // For post increment/decrement, return the value of the parameter itself as the result.
3094 if (op == EOpPostIncrement || op == EOpPostDecrement)
3095 {
3096 result = parameters[0];
3097 }
3098
3099 return result;
3100 }
3101
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3102 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3103 {
3104 const TOperator op = node->getOp();
3105 TIntermTyped *operand = node->getChildNode(0)->getAsTyped();
3106 const TType &operandType = operand->getType();
3107
3108 const SpirvDecorations resultDecorations = mBuilder.getDecorations(node->getType());
3109 const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3110
3111 // Load the left and right values.
3112 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3113 ASSERT(parameters.size() == 2);
3114
3115 // In GLSL, operators == and != can operate on the following:
3116 //
3117 // - scalars: There's a SPIR-V instruction for this,
3118 // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3119 // with OpAll/OpAny for == and != respectively,
3120 // - matrices: Comparison must be done column by column and the result reduced,
3121 // - arrays: Comparison must be done on every array element and the result reduced,
3122 // - structs: Comparison must be done on each field and the result reduced.
3123 //
3124 // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3125 // more complex type, which is recursively traversed. The results are accumulated in a list
3126 // that is then reduced 4 by 4 elements until a single boolean is produced.
3127
3128 spirv::LiteralIntegerList currentAccessChain;
3129 spirv::IdRefList intermediateResults;
3130
3131 createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3132 operandDecorations, resultDecorations, ¤tAccessChain,
3133 &intermediateResults);
3134
3135 // Make sure the function correctly pushes and pops access chain indices.
3136 ASSERT(currentAccessChain.empty());
3137
3138 // Reduce the intermediate results.
3139 ASSERT(!intermediateResults.empty());
3140
3141 // The following code implements this algorithm, assuming N bools are to be reduced:
3142 //
3143 // Reduced To Reduce
3144 // {b1} {b2, b3, ..., bN} Initial state
3145 // Loop
3146 // {b1, b2, b3, b4} {b5, b6, ..., bN} Take up to 3 new bools
3147 // {r1} {b5, b6, ..., bN} Reduce it
3148 // Repeat
3149 //
3150 // In the end, a single value is left.
3151 size_t reducedCount = 0;
3152 spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3153 while (reducedCount < intermediateResults.size())
3154 {
3155 // Take up to 3 new bools.
3156 size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3157 for (size_t i = 0; i < toTakeCount; ++i)
3158 {
3159 toReduce.push_back(intermediateResults[reducedCount++]);
3160 }
3161
3162 // Reduce them to one bool.
3163 const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3164
3165 // Replace the list of bools to reduce with the reduced one.
3166 toReduce.clear();
3167 toReduce.push_back(result);
3168 }
3169
3170 ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3171 return toReduce[0];
3172 }
3173
// Generates SPIR-V for the atomic* and imageAtomic* built-ins.  The first argument of the
// built-in provides the pointer the atomic operation acts on (for imageAtomic*, a pointer to
// the texel is first derived with OpImageTexelPointer); the remaining arguments are rvalues.
spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
                                                       spirv::IdRef resultTypeId)
{
    const TType &operandType           = node->getChildNode(0)->getAsTyped()->getType();
    const TBasicType operandBasicType  = operandType.getBasicType();
    const bool isImage                 = IsImage(operandBasicType);

    // Most atomic instructions are in the form of:
    //
    //     %result = OpAtomicX %pointer Scope MemorySemantics %value
    //
    // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
    // order from GLSL):
    //
    //     %result = OpAtomicCompareExchange %pointer
    //                                       Scope MemorySemantics MemorySemantics
    //                                       %value %comparator
    //
    // In all cases, the first parameter is the pointer, and the rest are rvalues.
    //
    // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
    // operation is being performed.
    const size_t parameterCount       = node->getChildCount();
    size_t imagePointerParameterCount = 0;
    spirv::IdRef pointerId;
    spirv::IdRefList imagePointerParameters;
    spirv::IdRefList parameters;

    // Count how many of the leading arguments (after the image itself) feed
    // OpImageTexelPointer rather than the atomic instruction.
    if (isImage)
    {
        // One parameter for coordinates.
        ++imagePointerParameterCount;
        if (IsImageMS(operandBasicType))
        {
            // One parameter for samples.
            ++imagePointerParameterCount;
        }
    }

    // At minimum: the pointer/image, any texel-pointer arguments, and one value operand.
    ASSERT(parameterCount >= 2 + imagePointerParameterCount);

    // The first argument is taken as an lvalue pointer; the rest are loaded as rvalues and
    // split between the texel-pointer arguments and the atomic operands.
    pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
    for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
    {
        NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
        const spirv::IdRef parameter = accessChainLoad(
            &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);

        // imageAtomic* built-ins have a few additional parameters right after the image.  These are
        // kept separately for use with OpImageTexelPointer.
        if (paramIndex <= imagePointerParameterCount)
        {
            imagePointerParameters.push_back(parameter);
        }
        else
        {
            parameters.push_back(parameter);
        }
    }

    // The scope of the operation is always Device as we don't enable the Vulkan memory model
    // extension.
    const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);

    // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
    const spirv::IdMemorySemantics semanticsId =
        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);

    WriteAtomicOp writeAtomicOp = nullptr;

    const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

    // Determine whether the operation is on ints or uints.
    const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;

    // For images, convert the pointer to the image to a pointer to a texel in the image.
    if (isImage)
    {
        const spirv::IdRef texelTypePointerId =
            mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
        const spirv::IdRef texelPointerId = mBuilder.getNewId({});

        // For non-multisampled images a sample index of 0 is supplied.
        const spirv::IdRef coordinate = imagePointerParameters[0];
        spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
                                                                : mBuilder.getUintConstant(0);

        spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
                                      texelPointerId, pointerId, coordinate, sample);

        pointerId = texelPointerId;
    }

    // Map the GLSL built-in to the SPIR-V atomic instruction, picking the signed/unsigned
    // variant where the instruction set distinguishes them.
    switch (node->getOp())
    {
        case EOpAtomicAdd:
        case EOpImageAtomicAdd:
            writeAtomicOp = spirv::WriteAtomicIAdd;
            break;
        case EOpAtomicMin:
        case EOpImageAtomicMin:
            writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
            break;
        case EOpAtomicMax:
        case EOpImageAtomicMax:
            writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
            break;
        case EOpAtomicAnd:
        case EOpImageAtomicAnd:
            writeAtomicOp = spirv::WriteAtomicAnd;
            break;
        case EOpAtomicOr:
        case EOpImageAtomicOr:
            writeAtomicOp = spirv::WriteAtomicOr;
            break;
        case EOpAtomicXor:
        case EOpImageAtomicXor:
            writeAtomicOp = spirv::WriteAtomicXor;
            break;
        case EOpAtomicExchange:
        case EOpImageAtomicExchange:
            writeAtomicOp = spirv::WriteAtomicExchange;
            break;
        case EOpAtomicCompSwap:
        case EOpImageAtomicCompSwap:
            // Generate this special instruction right here and early out.  Note again that the
            // value and compare parameters of OpAtomicCompareExchange are in the opposite order
            // from GLSL.
            ASSERT(parameters.size() == 2);
            spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
                                              result, pointerId, scopeId, semanticsId, semanticsId,
                                              parameters[1], parameters[0]);
            return result;
        default:
            UNREACHABLE();
    }

    // Write the instruction.
    ASSERT(parameters.size() == 1);
    writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
                  semanticsId, parameters[0]);

    return result;
}
3317
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3318 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3319 spirv::IdRef resultTypeId)
3320 {
3321 const TOperator op = node->getOp();
3322 const TFunction *function = node->getAsAggregate()->getFunction();
3323 const TType &samplerType = function->getParam(0)->getType();
3324 const TBasicType samplerBasicType = samplerType.getBasicType();
3325
3326 // Load the parameters.
3327 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3328
3329 // GLSL texture* and image* built-ins map to the following SPIR-V instructions. Some of these
3330 // instructions take a "sampled image" while the others take the image itself. In these
3331 // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3332 // parameters while the rest are bundled in a list of image operands.
3333 //
3334 // Image operations that query:
3335 //
3336 // - OpImageQuerySizeLod
3337 // - OpImageQuerySize
3338 // - OpImageQueryLod <-- sampled image
3339 // - OpImageQueryLevels
3340 // - OpImageQuerySamples
3341 //
3342 // Image operations that read/write:
3343 //
3344 // - OpImageSampleImplicitLod <-- sampled image
3345 // - OpImageSampleExplicitLod <-- sampled image
3346 // - OpImageSampleDrefImplicitLod <-- sampled image
3347 // - OpImageSampleDrefExplicitLod <-- sampled image
3348 // - OpImageSampleProjImplicitLod <-- sampled image
3349 // - OpImageSampleProjExplicitLod <-- sampled image
3350 // - OpImageSampleProjDrefImplicitLod <-- sampled image
3351 // - OpImageSampleProjDrefExplicitLod <-- sampled image
3352 // - OpImageFetch
3353 // - OpImageGather <-- sampled image
3354 // - OpImageDrefGather <-- sampled image
3355 // - OpImageRead
3356 // - OpImageWrite
3357 //
3358 // The additional image parameters are:
3359 //
3360 // - Bias: Only used with ImplicitLod.
3361 // - Lod: Only used with ExplicitLod.
3362 // - Grad: 2x operands; dx and dy. Only used with ExplicitLod.
3363 // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3364 // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3365 // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3366 // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3367 //
3368 // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3369 // the SPIR-V image out of a SPIR-V sampled image.
3370
3371 // The first parameter, which is either a sampled image or an image. Some GLSL built-ins
3372 // receive a sampled image but their SPIR-V equivalent expects an image. OpImage is used in
3373 // that case.
3374 spirv::IdRef image = parameters[0];
3375 bool extractImageFromSampledImage = false;
3376
3377 // The argument index for different possible parameters. 0 indicates that the argument is
3378 // unused. Coordinates are usually at index 1, so it's pre-initialized.
3379 size_t coordinatesIndex = 1;
3380 size_t biasIndex = 0;
3381 size_t lodIndex = 0;
3382 size_t compareIndex = 0;
3383 size_t dPdxIndex = 0;
3384 size_t dPdyIndex = 0;
3385 size_t offsetIndex = 0;
3386 size_t offsetsIndex = 0;
3387 size_t gatherComponentIndex = 0;
3388 size_t sampleIndex = 0;
3389 size_t dataIndex = 0;
3390
3391 // Whether this is a Dref variant of a sample call.
3392 bool isDref = IsShadowSampler(samplerBasicType);
3393 // Whether this is a Proj variant of a sample call.
3394 bool isProj = false;
3395
3396 // The SPIR-V op used to implement the built-in. For OpImageSample* instructions,
3397 // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3398 // and |isProj|.
3399 spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3400
3401 // Organize the parameters and decide the SPIR-V Op to use.
3402 switch (op)
3403 {
3404 case EOpTexture2D:
3405 case EOpTextureCube:
3406 case EOpTexture3D:
3407 case EOpShadow2DEXT:
3408 case EOpTexture2DRect:
3409 case EOpTextureVideoWEBGL:
3410 case EOpTexture:
3411
3412 case EOpTexture2DBias:
3413 case EOpTextureCubeBias:
3414 case EOpTexture3DBias:
3415 case EOpTextureBias:
3416
3417 // For shadow cube arrays, the compare value is specified through an additional
            // parameter, while for the rest it is taken out of the coordinates.
3419 if (function->getParamCount() == 3)
3420 {
3421 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3422 {
3423 compareIndex = 2;
3424 }
3425 else
3426 {
3427 biasIndex = 2;
3428 }
3429 }
3430 else if (function->getParamCount() == 4 &&
3431 samplerBasicType == EbtSamplerCubeArrayShadow)
3432 {
3433 compareIndex = 2;
3434 biasIndex = 3;
3435 }
3436 break;
3437
3438 case EOpTexture2DProj:
3439 case EOpTexture3DProj:
3440 case EOpShadow2DProjEXT:
3441 case EOpTexture2DRectProj:
3442 case EOpTextureProj:
3443
3444 case EOpTexture2DProjBias:
3445 case EOpTexture3DProjBias:
3446 case EOpTextureProjBias:
3447
3448 isProj = true;
3449 if (function->getParamCount() == 3)
3450 {
3451 biasIndex = 2;
3452 }
3453 break;
3454
3455 case EOpTexture3DLod:
3456
3457 case EOpTexture2DLodVS:
3458 case EOpTextureCubeLodVS:
3459
3460 case EOpTexture2DLodEXTFS:
3461 case EOpTextureCubeLodEXTFS:
3462 case EOpTextureLod:
3463
3464 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3465 {
3466 ASSERT(function->getParamCount() == 4);
3467 compareIndex = 2;
3468 lodIndex = 3;
3469 }
3470 else
3471 {
3472 ASSERT(function->getParamCount() == 3);
3473 lodIndex = 2;
3474 }
3475 break;
3476
3477 case EOpTexture3DProjLod:
3478
3479 case EOpTexture2DProjLodVS:
3480
3481 case EOpTexture2DProjLodEXTFS:
3482 case EOpTextureProjLod:
3483
3484 ASSERT(function->getParamCount() == 3);
3485 isProj = true;
3486 lodIndex = 2;
3487 break;
3488
3489 case EOpTexelFetch:
3490 case EOpTexelFetchOffset:
3491 // texelFetch has the following forms:
3492 //
3493 // - texelFetch(sampler, P);
3494 // - texelFetch(sampler, P, lod);
3495 // - texelFetch(samplerMS, P, sample);
3496 //
3497 // texelFetchOffset has an additional offset parameter at the end.
3498 //
3499 // In SPIR-V, OpImageFetch is used which operates on the image itself.
3500 spirvOp = spv::OpImageFetch;
3501 extractImageFromSampledImage = true;
3502
3503 if (IsSamplerMS(samplerBasicType))
3504 {
3505 ASSERT(function->getParamCount() == 3);
3506 sampleIndex = 2;
3507 }
3508 else if (function->getParamCount() >= 3)
3509 {
3510 lodIndex = 2;
3511 }
3512 if (op == EOpTexelFetchOffset)
3513 {
3514 offsetIndex = function->getParamCount() - 1;
3515 }
3516 break;
3517
3518 case EOpTexture2DGradEXT:
3519 case EOpTextureCubeGradEXT:
3520 case EOpTextureGrad:
3521
3522 ASSERT(function->getParamCount() == 4);
3523 dPdxIndex = 2;
3524 dPdyIndex = 3;
3525 break;
3526
3527 case EOpTexture2DProjGradEXT:
3528 case EOpTextureProjGrad:
3529
3530 ASSERT(function->getParamCount() == 4);
3531 isProj = true;
3532 dPdxIndex = 2;
3533 dPdyIndex = 3;
3534 break;
3535
3536 case EOpTextureOffset:
3537 case EOpTextureOffsetBias:
3538
3539 ASSERT(function->getParamCount() >= 3);
3540 offsetIndex = 2;
3541 if (function->getParamCount() == 4)
3542 {
3543 biasIndex = 3;
3544 }
3545 break;
3546
3547 case EOpTextureProjOffset:
3548 case EOpTextureProjOffsetBias:
3549
3550 ASSERT(function->getParamCount() >= 3);
3551 isProj = true;
3552 offsetIndex = 2;
3553 if (function->getParamCount() == 4)
3554 {
3555 biasIndex = 3;
3556 }
3557 break;
3558
3559 case EOpTextureLodOffset:
3560
3561 ASSERT(function->getParamCount() == 4);
3562 lodIndex = 2;
3563 offsetIndex = 3;
3564 break;
3565
3566 case EOpTextureProjLodOffset:
3567
3568 ASSERT(function->getParamCount() == 4);
3569 isProj = true;
3570 lodIndex = 2;
3571 offsetIndex = 3;
3572 break;
3573
3574 case EOpTextureGradOffset:
3575
3576 ASSERT(function->getParamCount() == 5);
3577 dPdxIndex = 2;
3578 dPdyIndex = 3;
3579 offsetIndex = 4;
3580 break;
3581
3582 case EOpTextureProjGradOffset:
3583
3584 ASSERT(function->getParamCount() == 5);
3585 isProj = true;
3586 dPdxIndex = 2;
3587 dPdyIndex = 3;
3588 offsetIndex = 4;
3589 break;
3590
3591 case EOpTextureGather:
3592
3593 // For shadow textures, refZ (same as Dref) is specified as the last argument.
3594 // Otherwise a component may be specified which defaults to 0 if not specified.
3595 spirvOp = spv::OpImageGather;
3596 if (isDref)
3597 {
3598 ASSERT(function->getParamCount() == 3);
3599 compareIndex = 2;
3600 }
3601 else if (function->getParamCount() == 3)
3602 {
3603 gatherComponentIndex = 2;
3604 }
3605 break;
3606
3607 case EOpTextureGatherOffset:
3608 case EOpTextureGatherOffsetComp:
3609
3610 case EOpTextureGatherOffsets:
3611 case EOpTextureGatherOffsetsComp:
3612
3613 // textureGatherOffset and textureGatherOffsets have the following forms:
3614 //
3615 // - texelGatherOffset*(sampler, P, offset*);
3616 // - texelGatherOffset*(sampler, P, offset*, component);
3617 // - texelGatherOffset*(sampler, P, refZ, offset*);
3618 //
3619 spirvOp = spv::OpImageGather;
3620 if (isDref)
3621 {
3622 ASSERT(function->getParamCount() == 4);
3623 compareIndex = 2;
3624 }
3625 else if (function->getParamCount() == 4)
3626 {
3627 gatherComponentIndex = 3;
3628 }
3629
3630 ASSERT(function->getParamCount() >= 3);
3631 if (BuiltInGroup::IsTextureGatherOffset(op))
3632 {
3633 offsetIndex = isDref ? 3 : 2;
3634 }
3635 else
3636 {
3637 offsetsIndex = isDref ? 3 : 2;
3638 }
3639 break;
3640
3641 case EOpImageStore:
3642 // imageStore has the following forms:
3643 //
3644 // - imageStore(image, P, data);
3645 // - imageStore(imageMS, P, sample, data);
3646 //
3647 spirvOp = spv::OpImageWrite;
3648 if (IsSamplerMS(samplerBasicType))
3649 {
3650 ASSERT(function->getParamCount() == 4);
3651 sampleIndex = 2;
3652 dataIndex = 3;
3653 }
3654 else
3655 {
3656 ASSERT(function->getParamCount() == 3);
3657 dataIndex = 2;
3658 }
3659 break;
3660
3661 case EOpImageLoad:
            // imageLoad has the following forms:
3663 //
3664 // - imageLoad(image, P);
3665 // - imageLoad(imageMS, P, sample);
3666 //
3667 spirvOp = spv::OpImageRead;
3668 if (IsSamplerMS(samplerBasicType))
3669 {
3670 ASSERT(function->getParamCount() == 3);
3671 sampleIndex = 2;
3672 }
3673 else
3674 {
3675 ASSERT(function->getParamCount() == 2);
3676 }
3677 break;
3678
3679 // Queries:
3680 case EOpTextureSize:
3681 case EOpImageSize:
3682 // textureSize has the following forms:
3683 //
3684 // - textureSize(sampler);
3685 // - textureSize(sampler, lod);
3686 //
3687 // while imageSize has only one form:
3688 //
3689 // - imageSize(image);
3690 //
3691 extractImageFromSampledImage = true;
3692 if (function->getParamCount() == 2)
3693 {
3694 spirvOp = spv::OpImageQuerySizeLod;
3695 lodIndex = 1;
3696 }
3697 else
3698 {
3699 spirvOp = spv::OpImageQuerySize;
3700 }
3701 // No coordinates parameter.
3702 coordinatesIndex = 0;
3703 // No dref parameter.
3704 isDref = false;
3705 break;
3706
3707 default:
3708 UNREACHABLE();
3709 }
3710
3711 // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3712 // one as they are not allowed in SPIR-V outside fragment shaders.
3713 const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3714 IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3715 IsImageMS(samplerBasicType);
3716 const bool makeLodExplicit =
3717 mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3718 !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3719
3720 // Apply any necessary fix up.
3721
3722 if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3723 {
3724 // Get the (non-sampled) image type.
3725 SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3726 ASSERT(!imageType.isSamplerBaseImage);
3727 imageType.isSamplerBaseImage = true;
3728 const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3729
3730 // Use OpImage to get the image out of the sampled image.
3731 const spirv::IdRef extractedImage = mBuilder.getNewId({});
3732 spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3733 extractedImage, image);
3734 image = extractedImage;
3735 }
3736
3737 // Gather operands as necessary.
3738
3739 // - Coordinates
3740 uint8_t coordinatesChannelCount = 0;
3741 spirv::IdRef coordinatesId;
3742 const TType *coordinatesType = nullptr;
3743 if (coordinatesIndex > 0)
3744 {
3745 coordinatesId = parameters[coordinatesIndex];
3746 coordinatesType = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3747 coordinatesChannelCount = coordinatesType->getNominalSize();
3748 }
3749
3750 // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3751 // * coordinates.z for proj variants
3752 // * coordinates.<last> for others
3753 spirv::IdRef drefId;
3754 if (compareIndex > 0)
3755 {
3756 drefId = parameters[compareIndex];
3757 }
3758 else if (isDref)
3759 {
3760 // Get the component index
3761 ASSERT(coordinatesChannelCount > 0);
3762 uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3763
3764 // Get the component type
3765 SpirvType drefSpirvType = mBuilder.getSpirvType(*coordinatesType, {});
3766 drefSpirvType.primarySize = 1;
3767 const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3768
3769 // Extract the dref component out of coordinates.
3770 drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3771 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3772 coordinatesId, {spirv::LiteralInteger(drefComponent)});
3773 }
3774
3775 // - Gather component
3776 spirv::IdRef gatherComponentId;
3777 if (gatherComponentIndex > 0)
3778 {
3779 gatherComponentId = parameters[gatherComponentIndex];
3780 }
3781 else if (spirvOp == spv::OpImageGather)
3782 {
3783 // If comp is not specified, component 0 is taken as default.
3784 gatherComponentId = mBuilder.getIntConstant(0);
3785 }
3786
3787 // - Image write data
3788 spirv::IdRef dataId;
3789 if (dataIndex > 0)
3790 {
3791 dataId = parameters[dataIndex];
3792 }
3793
3794 // - Other operands
3795 spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3796 spirv::IdRefList imageOperandsList;
3797
3798 if (biasIndex > 0)
3799 {
3800 operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3801 imageOperandsList.push_back(parameters[biasIndex]);
3802 }
3803 if (lodIndex > 0)
3804 {
3805 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3806 imageOperandsList.push_back(parameters[lodIndex]);
3807 }
3808 else if (makeLodExplicit)
3809 {
3810 // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3811 // lod 0.
3812 operandsMask = operandsMask | spv::ImageOperandsLodMask;
3813 imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3814 : mBuilder.getFloatConstant(0));
3815 }
3816 if (dPdxIndex > 0)
3817 {
3818 ASSERT(dPdyIndex > 0);
3819 operandsMask = operandsMask | spv::ImageOperandsGradMask;
3820 imageOperandsList.push_back(parameters[dPdxIndex]);
3821 imageOperandsList.push_back(parameters[dPdyIndex]);
3822 }
3823 if (offsetIndex > 0)
3824 {
3825 // Non-const offsets require the ImageGatherExtended feature.
3826 if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3827 {
3828 operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3829 }
3830 else
3831 {
3832 ASSERT(spirvOp == spv::OpImageGather);
3833
3834 operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3835 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3836 }
3837 imageOperandsList.push_back(parameters[offsetIndex]);
3838 }
3839 if (offsetsIndex > 0)
3840 {
3841 ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3842
3843 operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3844 mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3845 imageOperandsList.push_back(parameters[offsetsIndex]);
3846 }
3847 if (sampleIndex > 0)
3848 {
3849 operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3850 imageOperandsList.push_back(parameters[sampleIndex]);
3851 }
3852
3853 const spv::ImageOperandsMask *imageOperands =
3854 imageOperandsList.empty() ? nullptr : &operandsMask;
3855
3856 // GLSL and SPIR-V are different in the way the projective component is specified:
3857 //
3858 // In GLSL:
3859 //
3860 // > The texture coordinates consumed from P, not including the last component of P, are divided
3861 // > by the last component of P.
3862 //
3863 // In SPIR-V, there's a similar language (division by last element), but with the following
3864 // added:
3865 //
3866 // > ... all unused components will appear after all used components.
3867 //
3868 // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3869 // where P.z is ignored. In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3870 //
3871 if (isProj)
3872 {
3873 uint8_t requiredChannelCount = coordinatesChannelCount;
3874 // texture*Proj* operate on the following parameters:
3875 //
3876 // - sampler2D, vec3 P
3877 // - sampler2D, vec4 P
3878 // - sampler2DRect, vec3 P
3879 // - sampler2DRect, vec4 P
3880 // - sampler3D, vec4 P
3881 // - sampler2DShadow, vec4 P
3882 // - sampler2DRectShadow, vec4 P
3883 //
3884 // Of these cases, only (sampler2D*, vec4 P) requires moving the proj channel from .w to the
3885 // appropriate location (.y for 1D and .z for 2D).
3886 if (IsSampler2D(samplerBasicType))
3887 {
3888 requiredChannelCount = 3;
3889 }
3890 if (requiredChannelCount != coordinatesChannelCount)
3891 {
3892 ASSERT(coordinatesChannelCount == 4);
3893
3894 // Get the component type
3895 SpirvType spirvType = mBuilder.getSpirvType(*coordinatesType, {});
3896 const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3897 spirvType.primarySize = 1;
3898 const spirv::IdRef channelTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3899
3900 // Extract the last component out of coordinates.
3901 const spirv::IdRef projChannelId =
3902 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3903 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3904 projChannelId, coordinatesId,
3905 {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3906
3907 // Insert it after the channels that are consumed. The extra channels are ignored per
3908 // the SPIR-V spec.
3909 const spirv::IdRef newCoordinatesId =
3910 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3911 spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3912 newCoordinatesId, projChannelId, coordinatesId,
3913 {spirv::LiteralInteger(requiredChannelCount - 1)});
3914 coordinatesId = newCoordinatesId;
3915 }
3916 }
3917
3918 // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3919 if (spirvOp == spv::OpImageSampleImplicitLod)
3920 {
3921 ASSERT(!noLodSupport);
3922 const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3923 if (isDref)
3924 {
3925 if (isProj)
3926 {
3927 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3928 : spv::OpImageSampleProjDrefImplicitLod;
3929 }
3930 else
3931 {
3932 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3933 : spv::OpImageSampleDrefImplicitLod;
3934 }
3935 }
3936 else
3937 {
3938 if (isProj)
3939 {
3940 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3941 : spv::OpImageSampleProjImplicitLod;
3942 }
3943 else
3944 {
3945 spirvOp =
3946 isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3947 }
3948 }
3949 }
3950 if (spirvOp == spv::OpImageGather && isDref)
3951 {
3952 spirvOp = spv::OpImageDrefGather;
3953 }
3954
3955 spirv::IdRef result;
3956 if (spirvOp != spv::OpImageWrite)
3957 {
3958 result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3959 }
3960
3961 switch (spirvOp)
3962 {
3963 case spv::OpImageQuerySizeLod:
3964 mBuilder.addCapability(spv::CapabilityImageQuery);
3965 ASSERT(imageOperandsList.size() == 1);
3966 spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3967 result, image, imageOperandsList[0]);
3968 break;
3969 case spv::OpImageQuerySize:
3970 mBuilder.addCapability(spv::CapabilityImageQuery);
3971 spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3972 result, image);
3973 break;
3974 case spv::OpImageQueryLod:
3975 mBuilder.addCapability(spv::CapabilityImageQuery);
3976 spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3977 image, coordinatesId);
3978 break;
3979 case spv::OpImageQueryLevels:
3980 mBuilder.addCapability(spv::CapabilityImageQuery);
3981 spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3982 result, image);
3983 break;
3984 case spv::OpImageQuerySamples:
3985 mBuilder.addCapability(spv::CapabilityImageQuery);
3986 spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3987 result, image);
3988 break;
3989 case spv::OpImageSampleImplicitLod:
3990 spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3991 resultTypeId, result, image, coordinatesId,
3992 imageOperands, imageOperandsList);
3993 break;
3994 case spv::OpImageSampleExplicitLod:
3995 spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3996 resultTypeId, result, image, coordinatesId,
3997 *imageOperands, imageOperandsList);
3998 break;
3999 case spv::OpImageSampleDrefImplicitLod:
4000 spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4001 resultTypeId, result, image, coordinatesId,
4002 drefId, imageOperands, imageOperandsList);
4003 break;
4004 case spv::OpImageSampleDrefExplicitLod:
4005 spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4006 resultTypeId, result, image, coordinatesId,
4007 drefId, *imageOperands, imageOperandsList);
4008 break;
4009 case spv::OpImageSampleProjImplicitLod:
4010 spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4011 resultTypeId, result, image, coordinatesId,
4012 imageOperands, imageOperandsList);
4013 break;
4014 case spv::OpImageSampleProjExplicitLod:
4015 spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4016 resultTypeId, result, image, coordinatesId,
4017 *imageOperands, imageOperandsList);
4018 break;
4019 case spv::OpImageSampleProjDrefImplicitLod:
4020 spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4021 resultTypeId, result, image, coordinatesId,
4022 drefId, imageOperands, imageOperandsList);
4023 break;
4024 case spv::OpImageSampleProjDrefExplicitLod:
4025 spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4026 resultTypeId, result, image, coordinatesId,
4027 drefId, *imageOperands, imageOperandsList);
4028 break;
4029 case spv::OpImageFetch:
4030 spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4031 image, coordinatesId, imageOperands, imageOperandsList);
4032 break;
4033 case spv::OpImageGather:
4034 spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4035 image, coordinatesId, gatherComponentId, imageOperands,
4036 imageOperandsList);
4037 break;
4038 case spv::OpImageDrefGather:
4039 spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4040 result, image, coordinatesId, drefId, imageOperands,
4041 imageOperandsList);
4042 break;
4043 case spv::OpImageRead:
4044 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4045 image, coordinatesId, imageOperands, imageOperandsList);
4046 break;
4047 case spv::OpImageWrite:
4048 spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4049 dataId, imageOperands, imageOperandsList);
4050 break;
4051 default:
4052 UNREACHABLE();
4053 }
4054
4055 return result;
4056 }
4057
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4058 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4059 spirv::IdRef resultTypeId)
4060 {
4061 // Load the parameters.
4062 spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4063 const spirv::IdRef image = parameters[0];
4064
4065 // If multisampled, an additional parameter specifies the sample. This is passed through as an
4066 // extra image operand.
4067 const bool hasSampleParam = parameters.size() == 2;
4068 const spv::ImageOperandsMask operandsMask =
4069 hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4070 spirv::IdRefList imageOperandsList;
4071 if (hasSampleParam)
4072 {
4073 imageOperandsList.push_back(parameters[1]);
4074 }
4075
4076 // |subpassLoad| is implemented with OpImageRead. This OP takes a coordinate, which is unused
4077 // and is set to (0, 0) here.
4078 const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4079
4080 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4081 spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4082 coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4083
4084 return result;
4085 }
4086
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4087 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4088 spirv::IdRef resultTypeId)
4089 {
4090 spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4091
4092 mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4093
4094 switch (node->getOp())
4095 {
4096 case EOpInterpolateAtCentroid:
4097 extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4098 break;
4099 case EOpInterpolateAtSample:
4100 extendedInst = spv::GLSLstd450InterpolateAtSample;
4101 break;
4102 case EOpInterpolateAtOffset:
4103 extendedInst = spv::GLSLstd450InterpolateAtOffset;
4104 break;
4105 default:
4106 UNREACHABLE();
4107 }
4108
4109 size_t childCount = node->getChildCount();
4110
4111 spirv::IdRefList parameters;
4112
4113 // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4114 // passed to the instruction. Except interpolateAtCentroid, another parameter follows.
4115 parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4116 if (childCount > 1)
4117 {
4118 parameters.push_back(accessChainLoad(
4119 &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4120 }
4121
4122 const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4123
4124 spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4125 mBuilder.getExtInstImportIdStd(),
4126 spirv::LiteralExtInstInteger(extendedInst), parameters);
4127
4128 return result;
4129 }
4130
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4131 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4132 const TType &valueType,
4133 const TType &expectedType,
4134 spirv::IdRef *resultTypeIdOut)
4135 {
4136 const TBasicType expectedBasicType = expectedType.getBasicType();
4137 if (valueType.getBasicType() == expectedBasicType)
4138 {
4139 return value;
4140 }
4141
4142 // Make sure no attempt is made to cast a matrix to int/uint.
4143 ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4144
4145 SpirvType valueSpirvType = mBuilder.getSpirvType(valueType, {});
4146 valueSpirvType.type = expectedBasicType;
4147 valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4148 const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4149
4150 const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4151
4152 // Write the instruction that casts between types. Different instructions are used based on the
4153 // types being converted.
4154 //
4155 // - int/uint <-> float: OpConvert*To*
4156 // - int <-> uint: OpBitcast
4157 // - bool --> int/uint/float: OpSelect with 0 and 1
4158 // - int/uint --> bool: OPINotEqual 0
4159 // - float --> bool: OpFUnordNotEqual 0
4160
4161 WriteUnaryOp writeUnaryOp = nullptr;
4162 WriteBinaryOp writeBinaryOp = nullptr;
4163 WriteTernaryOp writeTernaryOp = nullptr;
4164
4165 spirv::IdRef zero;
4166 spirv::IdRef one;
4167
4168 switch (valueType.getBasicType())
4169 {
4170 case EbtFloat:
4171 switch (expectedBasicType)
4172 {
4173 case EbtInt:
4174 writeUnaryOp = spirv::WriteConvertFToS;
4175 break;
4176 case EbtUInt:
4177 writeUnaryOp = spirv::WriteConvertFToU;
4178 break;
4179 case EbtBool:
4180 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4181 writeBinaryOp = spirv::WriteFUnordNotEqual;
4182 break;
4183 default:
4184 UNREACHABLE();
4185 }
4186 break;
4187
4188 case EbtInt:
4189 case EbtUInt:
4190 switch (expectedBasicType)
4191 {
4192 case EbtFloat:
4193 writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4194 : spirv::WriteConvertUToF;
4195 break;
4196 case EbtInt:
4197 case EbtUInt:
4198 writeUnaryOp = spirv::WriteBitcast;
4199 break;
4200 case EbtBool:
4201 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4202 writeBinaryOp = spirv::WriteINotEqual;
4203 break;
4204 default:
4205 UNREACHABLE();
4206 }
4207 break;
4208
4209 case EbtBool:
4210 writeTernaryOp = spirv::WriteSelect;
4211 switch (expectedBasicType)
4212 {
4213 case EbtFloat:
4214 zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4215 one = mBuilder.getVecConstant(1, valueType.getNominalSize());
4216 break;
4217 case EbtInt:
4218 zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4219 one = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4220 break;
4221 case EbtUInt:
4222 zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4223 one = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4224 break;
4225 default:
4226 UNREACHABLE();
4227 }
4228 break;
4229
4230 default:
4231 UNREACHABLE();
4232 }
4233
4234 if (writeUnaryOp)
4235 {
4236 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4237 }
4238 else if (writeBinaryOp)
4239 {
4240 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4241 }
4242 else
4243 {
4244 ASSERT(writeTernaryOp);
4245 writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4246 zero);
4247 }
4248
4249 if (resultTypeIdOut)
4250 {
4251 *resultTypeIdOut = castTypeId;
4252 }
4253
4254 return castValue;
4255 }
4256
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4257 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4258 const TType &valueType,
4259 const SpirvTypeSpec &valueTypeSpec,
4260 const SpirvTypeSpec &expectedTypeSpec,
4261 spirv::IdRef *resultTypeIdOut)
4262 {
4263 // If there's no difference in type specialization, there's nothing to cast.
4264 if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4265 valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4266 valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4267 valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4268 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4269 valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4270 {
4271 return value;
4272 }
4273
4274 // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4275 // specialized by |valueTypeSpec|. However, it's being assigned (for example through operator=,
4276 // used in a constructor or passed as a function argument) where the same GLSL type is expected
4277 // but with different SPIR-V type specialization (|expectedTypeSpec|).
4278 //
4279 // If SPIR-V 1.4 is available, use OpCopyLogical if possible. OpCopyLogical works on arrays and
4280 // structs, and only if the types are logically the same. This means that arrays and structs
4281 // can be copied with this instruction despite their SpirvTypeSpec being different. The only
4282 // exception is if there is a mismatch in the isOrHasBoolInInterfaceBlock type specialization
4283 // as it actually changes the type of the struct members.
4284 if (mCompileOptions.emitSPIRV14 && (valueType.isArray() || valueType.getStruct() != nullptr) &&
4285 valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
4286 {
4287 const spirv::IdRef expectedTypeId =
4288 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4289 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4290
4291 spirv::WriteCopyLogical(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId, expectedId,
4292 value);
4293 if (resultTypeIdOut)
4294 {
4295 *resultTypeIdOut = expectedTypeId;
4296 }
4297 return expectedId;
4298 }
4299
4300 // The following code recursively copies the array elements or struct fields and then constructs
4301 // the final result with the expected SPIR-V type.
4302
4303 // Interface blocks cannot be copied or passed as parameters in GLSL.
4304 ASSERT(!valueType.isInterfaceBlock());
4305
4306 spirv::IdRefList constituents;
4307
4308 if (valueType.isArray())
4309 {
4310 // Find the SPIR-V type specialization for the element type.
4311 SpirvTypeSpec valueElementTypeSpec = valueTypeSpec;
4312 SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4313
4314 const bool isElementBlock = valueType.getStruct() != nullptr;
4315 const bool isElementArray = valueType.isArrayOfArrays();
4316
4317 valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4318 expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4319
4320 // Get the element type id.
4321 TType elementType(valueType);
4322 elementType.toArrayElementType();
4323
4324 const spirv::IdRef elementTypeId =
4325 mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4326
4327 const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4328
4329 // Extract each element of the array and cast it to the expected type.
4330 for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4331 ++elementIndex)
4332 {
4333 const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4334 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4335 elementId, value, {spirv::LiteralInteger(elementIndex)});
4336
4337 constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4338 expectedElementTypeSpec, nullptr));
4339 }
4340 }
4341 else if (valueType.getStruct() != nullptr)
4342 {
4343 uint32_t fieldIndex = 0;
4344
4345 // Extract each field of the struct and cast it to the expected type.
4346 for (const TField *field : valueType.getStruct()->fields())
4347 {
4348 const TType &fieldType = *field->type();
4349
4350 // Find the SPIR-V type specialization for the field type.
4351 SpirvTypeSpec valueFieldTypeSpec = valueTypeSpec;
4352 SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4353
4354 valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4355 expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4356
4357 // Get the field type id.
4358 const spirv::IdRef fieldTypeId =
4359 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4360
4361 // Extract the field.
4362 const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4363 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4364 fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4365
4366 constituents.push_back(
4367 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4368 }
4369 }
4370 else
4371 {
4372 // Bool types in interface blocks are emulated with uint. bool<->uint cast is done here.
4373 ASSERT(valueType.getBasicType() == EbtBool);
4374 ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4375 expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4376
4377 TType emulatedValueType(valueType);
4378 emulatedValueType.setBasicType(EbtUInt);
4379 emulatedValueType.setPrecise(EbpLow);
4380
4381 // If value is loaded as uint, it needs to change to bool. If it's bool, it needs to change
4382 // to uint before storage.
4383 if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4384 {
4385 return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4386 }
4387 else
4388 {
4389 return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4390 }
4391 }
4392
4393 // Construct the value with the expected type from its cast constituents.
4394 const spirv::IdRef expectedTypeId =
4395 mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4396 const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4397
4398 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4399 expectedId, constituents);
4400
4401 if (resultTypeIdOut)
4402 {
4403 *resultTypeIdOut = expectedTypeId;
4404 }
4405
4406 return expectedId;
4407 }
4408
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4409 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4410 spirv::IdRef resultTypeId,
4411 spirv::IdRefList *parameters)
4412 {
4413 const TType &type = node->getType();
4414 if (type.isScalar())
4415 {
4416 // Nothing to do if the operation is applied to scalars.
4417 return;
4418 }
4419
4420 const size_t childCount = node->getChildCount();
4421
4422 for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4423 {
4424 const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4425
4426 // If the child is a scalar, replicate it to form a vector of the right size.
4427 if (childType.isScalar())
4428 {
4429 TType vectorType(type);
4430 if (vectorType.isMatrix())
4431 {
4432 vectorType.toMatrixColumnType();
4433 }
4434 (*parameters)[childIndex] = createConstructorVectorFromScalar(
4435 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4436 }
4437 }
4438 }
4439
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4440 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4441 const spirv::IdRefList &valueIds,
4442 spirv::IdRef typeId,
4443 const SpirvDecorations &decorations)
4444 {
4445 if (valueIds.size() == 2)
4446 {
4447 // If two values are given, and/or them directly.
4448 WriteBinaryOp writeBinaryOp =
4449 op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4450 const spirv::IdRef result = mBuilder.getNewId(decorations);
4451
4452 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4453 valueIds[1]);
4454 return result;
4455 }
4456
4457 WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4458 spirv::IdRef valueId = valueIds[0];
4459
4460 if (valueIds.size() > 2)
4461 {
4462 // If multiple values are given, construct a bool vector out of them first.
4463 const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4464 valueId = {mBuilder.getNewId(decorations)};
4465
4466 spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4467 valueIds);
4468 }
4469
4470 const spirv::IdRef result = mBuilder.getNewId(decorations);
4471 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4472
4473 return result;
4474 }
4475
// Emits code comparing two values of arbitrary type for equality (EOpEqual) or inequality
// (EOpNotEqual).  Composite operands (arrays, structs, matrices) are recursively decomposed
// down to scalars/vectors; |currentAccessChain| holds the literal indices of the leaf
// currently being compared within |leftId|/|rightId|.  Each leaf comparison produces a single
// bool that is appended to |intermediateResultsOut|; the caller combines those into the final
// result.
void OutputSPIRVTraverser::createCompareImpl(TOperator op,
                                             const TType &operandType,
                                             spirv::IdRef resultTypeId,
                                             spirv::IdRef leftId,
                                             spirv::IdRef rightId,
                                             const SpirvDecorations &operandDecorations,
                                             const SpirvDecorations &resultDecorations,
                                             spirv::LiteralIntegerList *currentAccessChain,
                                             spirv::IdRefList *intermediateResultsOut)
{
    const TBasicType basicType = operandType.getBasicType();
    const bool isFloat = basicType == EbtFloat;
    const bool isBool = basicType == EbtBool;

    WriteBinaryOp writeBinaryOp = nullptr;

    // For arrays, compare them element by element.
    if (operandType.isArray())
    {
        TType elementType(operandType);
        elementType.toArrayElementType();

        // Append one index slot to the access chain; it's overwritten per iteration and
        // removed again afterwards.
        currentAccessChain->emplace_back();
        for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
             ++elementIndex)
        {
            // Select the current element.
            currentAccessChain->back() = spirv::LiteralInteger(elementIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For structs, compare them field by field.
    if (operandType.getStruct() != nullptr)
    {
        uint32_t fieldIndex = 0;

        currentAccessChain->emplace_back();
        for (const TField *field : operandType.getStruct()->fields())
        {
            // Select the current field.
            currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);

            // Compare and accumulate the results.
            createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For matrices, compare them column by column.
    if (operandType.isMatrix())
    {
        TType columnType(operandType);
        columnType.toMatrixColumnType();

        currentAccessChain->emplace_back();
        for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
        {
            // Select the current column.
            currentAccessChain->back() = spirv::LiteralInteger(columnIndex);

            // Compare and accumulate the results.
            createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
                              resultDecorations, currentAccessChain, intermediateResultsOut);
        }
        currentAccessChain->pop_back();

        return;
    }

    // For scalars and vectors generate a single instruction for comparison.
    if (op == EOpEqual)
    {
        if (isFloat)
            writeBinaryOp = spirv::WriteFOrdEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalEqual;
        else
            writeBinaryOp = spirv::WriteIEqual;
    }
    else
    {
        ASSERT(op == EOpNotEqual);

        // FUnord (rather than FOrd) makes unordered (NaN) operands compare as not-equal.
        if (isFloat)
            writeBinaryOp = spirv::WriteFUnordNotEqual;
        else if (isBool)
            writeBinaryOp = spirv::WriteLogicalNotEqual;
        else
            writeBinaryOp = spirv::WriteINotEqual;
    }

    // Extract the scalar and vector from composite types, if any.
    spirv::IdRef leftComponentId = leftId;
    spirv::IdRef rightComponentId = rightId;
    if (!currentAccessChain->empty())
    {
        leftComponentId = mBuilder.getNewId(operandDecorations);
        rightComponentId = mBuilder.getNewId(operandDecorations);

        const spirv::IdRef componentTypeId =
            mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());

        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     leftComponentId, leftId, *currentAccessChain);
        spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
                                     rightComponentId, rightId, *currentAccessChain);
    }

    // Vector comparisons produce a bvecN which must be reduced to a single bool.  The raw
    // comparison result intentionally carries no decorations; |resultDecorations| is applied
    // by reduceBoolVector to the reduced value.
    const bool reduceResult = !operandType.isScalar();
    spirv::IdRef result = mBuilder.getNewId({});
    spirv::IdRef opResultTypeId = resultTypeId;
    if (reduceResult)
    {
        opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
    }

    // Write the comparison operation itself.
    writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
                  rightComponentId);

    // If it's a vector, reduce the result.
    if (reduceResult)
    {
        result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
    }

    intermediateResultsOut->push_back(result);
}
4614
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4615 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4616 size_t lvalueCount)
4617 {
4618 // The built-ins with lvalues are in one of the following forms:
4619 //
4620 // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4621 // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4622 //
4623 // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4624
4625 const size_t childCount = node->getChildCount();
4626 ASSERT(childCount >= 2);
4627
4628 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4629 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4630
4631 const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4632 const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4633
4634 ASSERT(lsbType.isScalar() || lsbType.isVector());
4635 ASSERT(msbType.isScalar() || msbType.isVector());
4636
4637 const BuiltInResultStruct key = {
4638 lsbType.getBasicType(),
4639 msbType.getBasicType(),
4640 static_cast<uint32_t>(lsbType.getNominalSize()),
4641 static_cast<uint32_t>(msbType.getNominalSize()),
4642 };
4643
4644 auto iter = mBuiltInResultStructMap.find(key);
4645 if (iter == mBuiltInResultStructMap.end())
4646 {
4647 // Create a TStructure and TType for the required structure.
4648 TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4649 TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4650
4651 TFieldList *fields = new TFieldList;
4652 fields->push_back(
4653 new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4654 fields->push_back(
4655 new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4656
4657 TStructure *structure =
4658 new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4659 fields, SymbolType::AngleInternal);
4660
4661 TType structType(structure, true);
4662
4663 // Get an id for the type and store in the hash map.
4664 const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4665 iter = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4666 }
4667
4668 return iter->second;
4669 }
4670
4671 // Once the builtin instruction is generated, the two return values are extracted from the
4672 // struct. These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4673 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4674 TIntermOperator *node,
4675 size_t lvalueCount,
4676 spirv::IdRef structValue,
4677 spirv::IdRef returnValue,
4678 spirv::IdRef returnValueType)
4679 {
4680 const size_t childCount = node->getChildCount();
4681 ASSERT(childCount >= 2);
4682
4683 TIntermTyped *lastChild = node->getChildNode(childCount - 1)->getAsTyped();
4684 TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4685
4686 if (lvalueCount == 1)
4687 {
4688 // The built-in is the form:
4689 //
4690 // lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4691
4692 // Field 0 is lsb, which is extracted as the builtin's return value.
4693 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4694 returnValue, structValue, {spirv::LiteralInteger(0)});
4695
4696 // Field 1 is msb, which is extracted and stored through the out parameter.
4697 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4698 structValue, 1);
4699 }
4700 else
4701 {
4702 // The built-in is the form:
4703 //
4704 // builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4705 ASSERT(lvalueCount == 2);
4706
4707 // Field 0 is lsb, which is extracted and stored through the second out parameter.
4708 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4709 structValue, 0);
4710
4711 // Field 1 is msb, which is extracted and stored through the first out parameter.
4712 storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4713 structValue, 1);
4714 }
4715 }
4716
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4717 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4718 TIntermTyped *param,
4719 spirv::IdRef structValue,
4720 uint32_t fieldIndex)
4721 {
4722 spirv::IdRef fieldTypeId = mBuilder.getTypeData(param->getType(), {}).id;
4723 spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4724
4725 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4726 structValue, {spirv::LiteralInteger(fieldIndex)});
4727
4728 accessChainStore(data, fieldValueId, param->getType());
4729 }
4730
visitSymbol(TIntermSymbol * node)4731 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4732 {
4733 // No-op visits to symbols that are being declared. They are handled in visitDeclaration.
4734 if (mIsSymbolBeingDeclared)
4735 {
4736 // Make sure this does not affect other symbols, for example in the initializer expression.
4737 mIsSymbolBeingDeclared = false;
4738 return;
4739 }
4740
4741 mNodeData.emplace_back();
4742
4743 // The symbol is either:
4744 //
4745 // - A specialization constant
4746 // - A variable (local, varying etc)
4747 // - An interface block
4748 // - A field of an unnamed interface block
4749 //
4750 // Specialization constants in SPIR-V are treated largely like constants, in which case make
4751 // this behave like visitConstantUnion().
4752
4753 const TType &type = node->getType();
4754 const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4755 const TSymbol *symbol = interfaceBlock;
4756 if (interfaceBlock == nullptr)
4757 {
4758 symbol = &node->variable();
4759 }
4760
4761 // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4762 // They are needed to determine the derived type in an access chain, but are not promoted in
4763 // intermediate nodes' TTypes.
4764 SpirvTypeSpec typeSpec;
4765 typeSpec.inferDefaults(type, mCompiler);
4766
4767 const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4768
4769 // If the symbol is a const variable, a const function parameter or specialization constant,
4770 // create an rvalue.
4771 if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4772 type.getQualifier() == EvqSpecConst)
4773 {
4774 ASSERT(interfaceBlock == nullptr);
4775 ASSERT(mSymbolIdMap.count(symbol) > 0);
4776 nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4777 return;
4778 }
4779
4780 // Otherwise create an lvalue.
4781 spv::StorageClass storageClass;
4782 const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4783
4784 nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4785
4786 // If a field of a nameless interface block, create an access chain.
4787 if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4788 {
4789 uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4790 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4791 }
4792
4793 // Add gl_PerVertex capabilities only if the field is actually used.
4794 switch (type.getQualifier())
4795 {
4796 case EvqClipDistance:
4797 mBuilder.addCapability(spv::CapabilityClipDistance);
4798 break;
4799 case EvqCullDistance:
4800 mBuilder.addCapability(spv::CapabilityCullDistance);
4801 break;
4802 default:
4803 break;
4804 }
4805 }
4806
visitConstantUnion(TIntermConstantUnion * node)4807 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4808 {
4809 mNodeData.emplace_back();
4810
4811 const TType &type = node->getType();
4812
4813 // Find out the expected type for this constant, so it can be cast right away and not need an
4814 // instruction to do that.
4815 TIntermNode *parent = getParentNode();
4816 const size_t childIndex = getParentChildIndex(PreVisit);
4817
4818 TBasicType expectedBasicType = type.getBasicType();
4819 if (parent->getAsAggregate())
4820 {
4821 TIntermAggregate *parentAggregate = parent->getAsAggregate();
4822
4823 // Note that only constructors can cast a type. There are two possibilities:
4824 //
4825 // - It's a struct constructor: The basic type must match that of the corresponding field of
4826 // the struct.
4827 // - It's a non struct constructor: The basic type must match that of the type being
4828 // constructed.
4829 if (parentAggregate->isConstructor())
4830 {
4831 const TType &parentType = parentAggregate->getType();
4832 const TStructure *structure = parentType.getStruct();
4833
4834 if (structure != nullptr && !parentType.isArray())
4835 {
4836 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4837 }
4838 else
4839 {
4840 expectedBasicType = parentAggregate->getType().getBasicType();
4841 }
4842 }
4843 }
4844
4845 const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
4846 const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4847 node->isConstantNullValue());
4848
4849 nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4850 }
4851
visitSwizzle(Visit visit,TIntermSwizzle * node)4852 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4853 {
4854 // Constants are expected to be folded.
4855 ASSERT(!node->hasConstantValue());
4856
4857 if (visit == PreVisit)
4858 {
4859 // Don't add an entry to the stack. The child will create one, which we won't pop.
4860 return true;
4861 }
4862
4863 ASSERT(visit == PostVisit);
4864 ASSERT(mNodeData.size() >= 1);
4865
4866 const TType &vectorType = node->getOperand()->getType();
4867 const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4868 const TVector<int> &swizzle = node->getSwizzleOffsets();
4869
4870 // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4871 // in order.
4872 bool isIdentity = swizzle.size() == vectorComponentCount;
4873 for (size_t index = 0; index < swizzle.size(); ++index)
4874 {
4875 isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4876 }
4877
4878 if (isIdentity)
4879 {
4880 return true;
4881 }
4882
4883 accessChainOnPush(&mNodeData.back(), vectorType, 0);
4884
4885 const spirv::IdRef typeId =
4886 mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4887
4888 accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4889
4890 return true;
4891 }
4892
// Generates code for binary expressions.  Node data stack protocol: no entry is pushed on
// PreVisit (the left child's entry is taken over as this node's); on PostVisit the right
// child's entry is popped and the combined result replaces the remaining entry.
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    // Constants are expected to be folded.
    ASSERT(!node->hasConstantValue());

    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The left child will create one, which we won't pop.
        return true;
    }

    // If this is a variable initialization node, defer any code generation to visitDeclaration.
    if (node->getOp() == EOpInitialize)
    {
        ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
        return true;
    }

    if (IsShortCircuitNeeded(node))
    {
        // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
        // |if| construct. At this point, the left-hand side is already evaluated, so we need to
        // create an appropriate conditional on in-visit and visit the right-hand-side inside the
        // conditional block. On post-visit, OpPhi is used to calculate the result.
        if (visit == InVisit)
        {
            startShortCircuit(node);
            return true;
        }

        spirv::IdRef typeId;
        const spirv::IdRef result = endShortCircuit(node, &typeId);

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);

        return true;
    }

    if (visit == InVisit)
    {
        // Left child visited. Take the entry it created as the current node's.
        ASSERT(mNodeData.size() >= 1);

        // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
        // add it to the access chain as literal.
        switch (node->getOp())
        {
            default:
                break;

            case EOpIndexDirect:
            case EOpIndexDirectStruct:
            case EOpIndexDirectInterfaceBlock:
                const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
                accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);

                const spirv::IdRef typeId =
                    mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
                accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);

                // Don't visit the right child, it's already processed.
                return false;
        }

        return true;
    }

    // There are at least two entries, one for the left node and one for the right one.
    ASSERT(mNodeData.size() >= 2);

    // For dynamic indexing and assignment, the result type must be derived with the left-hand
    // side's type spec; for everything else the default type spec is used.
    SpirvTypeSpec resultTypeSpec;
    if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
    {
        if (node->getOp() == EOpIndexIndirect)
        {
            accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
        }
        resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
    }
    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;

    // For EOpIndex* operations, push the right value as an index to the left value's access chain.
    // For the other operations, evaluate the expression.
    switch (node->getOp())
    {
        case EOpIndexDirect:
        case EOpIndexDirectStruct:
        case EOpIndexDirectInterfaceBlock:
            // Handled in InVisit above, which also skips the right child.
            UNREACHABLE();
            break;
        case EOpIndexIndirect:
        {
            // Load the index.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Dynamically indexing a vector selects a component; anything else extends the
            // access chain with the index.
            if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
            {
                accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
            }
            else
            {
                accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
            }
            break;
        }

        case EOpAssign:
        {
            // Load the right hand side of assignment.
            const spirv::IdRef rightValue =
                accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
            mNodeData.pop_back();

            // Store into the access chain. Since the result of the (a = b) expression is b, change
            // the access chain to an unindexed rvalue which is |rightValue|.
            accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
            nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
            break;
        }

        case EOpComma:
            // When the expression a,b is visited, all side effects of a and b are already
            // processed. What's left is to replace the expression with the result of b. This
            // is simply done by dropping the left node and placing the right node as the result.
            mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
            break;

        default:
            // Any other binary operation is evaluated directly, consuming both entries.
            const spirv::IdRef result = visitOperator(node, resultTypeId);
            mNodeData.pop_back();
            nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
            break;
    }

    return true;
}
5032
visitUnary(Visit visit,TIntermUnary * node)5033 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5034 {
5035 // Constants are expected to be folded.
5036 ASSERT(!node->hasConstantValue());
5037
5038 // Special case EOpArrayLength.
5039 if (node->getOp() == EOpArrayLength)
5040 {
5041 visitArrayLength(node);
5042
5043 // Children already visited.
5044 return false;
5045 }
5046
5047 if (visit == PreVisit)
5048 {
5049 // Don't add an entry to the stack. The child will create one, which we won't pop.
5050 return true;
5051 }
5052
5053 // It's a unary operation, so there can't be an InVisit.
5054 ASSERT(visit != InVisit);
5055
5056 // There is at least on entry for the child.
5057 ASSERT(mNodeData.size() >= 1);
5058
5059 const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5060 const spirv::IdRef result = visitOperator(node, resultTypeId);
5061
5062 // Keep the result as rvalue.
5063 nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5064
5065 return true;
5066 }
5067
// Generates code for the ternary operator |cond ? trueExpr : falseExpr|.  When neither side
// has side effects (and, prior to SPIR-V 1.4, the result is scalar or vector), both sides are
// evaluated unconditionally and OpSelect picks the result.  Otherwise an if-else construct is
// emitted and OpPhi merges the two results.
//
// The node data entry accumulates, in order: [trueValue, trueBlockId, falseValue,
// falseBlockId] in idList, with the (possibly replicated) condition as baseId.
bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
{
    if (visit == PreVisit)
    {
        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    size_t lastChildIndex = getLastTraversedChildIndex(visit);

    // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
    // if-else must be emitted. OpSelect is only used if neither side has a side effect. SPIR-V
    // prior to 1.4 requires the type to be either scalar or vector.
    const TType &type = node->getType();
    bool canUseOpSelect = (type.isScalar() || type.isVector() || mCompileOptions.emitSPIRV14) &&
                          !node->getTrueExpression()->hasSideEffects() &&
                          !node->getFalseExpression()->hasSideEffects();

    // Don't use OpSelect on buggy drivers. Technically this is only needed if the two sides don't
    // have matching use of RelaxedPrecision, but not worth being precise about it.
    if (mCompileOptions.avoidOpSelectWithMismatchingRelaxedPrecision)
    {
        canUseOpSelect = false;
    }

    if (lastChildIndex == 0)
    {
        const TType &conditionType = node->getCondition()->getType();

        spirv::IdRef typeId;
        spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);

        // If OpSelect can be used, keep the condition for later usage.
        if (canUseOpSelect)
        {
            // SPIR-V prior to 1.4 requires that the condition value have as many components as the
            // result. So when selecting between vectors, we must replicate the condition scalar.
            if (!mCompileOptions.emitSPIRV14 && type.isVector())
            {
                const TType &boolVectorType =
                    *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
                typeId =
                    mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
                conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
                                                                   typeId, {{conditionValue}});
            }
            // The condition is stored as the entry's baseId and read back on PostVisit.
            nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
            return true;
        }

        // Otherwise generate an if-else construct.

        // Three blocks necessary; the true, false and merge.
        mBuilder.startConditional(3, false, false);

        // Generate the branch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        const spirv::IdRef trueBlockId = conditional->blockIds[0];
        const spirv::IdRef falseBlockId = conditional->blockIds[1];
        const spirv::IdRef mergeBlockId = conditional->blockIds.back();

        mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
        // Keep the entry initialized; in this path the condition value is not read again on
        // PostVisit (OpPhi uses idList only).
        nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
        return true;
    }

    // Load the result of the true or false part, and keep it for the end. It's either used in
    // OpSelect or OpPhi.
    spirv::IdRef typeId;
    const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
    mNodeData.pop_back();
    mNodeData.back().idList.push_back(value);

    // Additionally store the id of block that has produced the result.
    mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());

    if (!canUseOpSelect)
    {
        // Move on to the next block.
        mBuilder.writeBranchConditionalBlockEnd();
    }

    // When done, generate either OpSelect or OpPhi.
    if (visit == PostVisit)
    {
        const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));

        // idList layout: [trueValue, trueBlockId, falseValue, falseBlockId].
        ASSERT(mNodeData.back().idList.size() == 4);
        const spirv::IdRef trueValue = mNodeData.back().idList[0].id;
        const spirv::IdRef trueBlockId = mNodeData.back().idList[1].id;
        const spirv::IdRef falseValue = mNodeData.back().idList[2].id;
        const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;

        if (canUseOpSelect)
        {
            const spirv::IdRef conditionValue = mNodeData.back().baseId;

            spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                               conditionValue, trueValue, falseValue);
        }
        else
        {
            spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
                            {spirv::PairIdRefIdRef{trueValue, trueBlockId},
                             spirv::PairIdRefIdRef{falseValue, falseBlockId}});

            mBuilder.endConditional();
        }

        // Replace the access chain with an rvalue that's the result.
        nodeDataInitRValue(&mNodeData.back(), result, typeId);
    }

    return true;
}
5184
visitIfElse(Visit visit,TIntermIfElse * node)5185 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5186 {
5187 // An if condition may or may not have an else block. When both blocks are present, the
5188 // translation is as follows:
5189 //
5190 // if (cond) { trueBody } else { falseBody }
5191 //
5192 // // pre-if block
5193 // %cond = ...
5194 // OpSelectionMerge %merge None
5195 // OpBranchConditional %cond %true %false
5196 //
5197 // %true = OpLabel
5198 // trueBody
5199 // OpBranch %merge
5200 //
5201 // %false = OpLabel
5202 // falseBody
5203 // OpBranch %merge
5204 //
5205 // // post-if block
5206 // %merge = OpLabel
5207 //
5208 // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5209 // condition and the %false block is removed. Due to the way ParseContext prunes compile-time
5210 // constant conditionals, the if block itself may also be missing, which is treated similarly.
5211
5212 // It's simpler if this function performs the traversal.
5213 ASSERT(visit == PreVisit);
5214
5215 // Visit the condition.
5216 node->getCondition()->traverse(this);
5217 const spirv::IdRef conditionValue =
5218 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5219
5220 // If both true and false blocks are missing, there's nothing to do.
5221 if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5222 {
5223 return false;
5224 }
5225
5226 // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5227 // else block (if any), and one for the merge block. getChildCount() works here as it
5228 // produces an identical count.
5229 mBuilder.startConditional(node->getChildCount(), false, false);
5230
5231 // Generate the branch instructions.
5232 const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5233
5234 const spirv::IdRef mergeBlock = conditional->blockIds.back();
5235 spirv::IdRef trueBlock = mergeBlock;
5236 spirv::IdRef falseBlock = mergeBlock;
5237
5238 size_t nextBlockIndex = 0;
5239 if (node->getTrueBlock())
5240 {
5241 trueBlock = conditional->blockIds[nextBlockIndex++];
5242 }
5243 if (node->getFalseBlock())
5244 {
5245 falseBlock = conditional->blockIds[nextBlockIndex++];
5246 }
5247
5248 mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5249
5250 // Visit the true block, if any.
5251 if (node->getTrueBlock())
5252 {
5253 node->getTrueBlock()->traverse(this);
5254 mBuilder.writeBranchConditionalBlockEnd();
5255 }
5256
5257 // Visit the false block, if any.
5258 if (node->getFalseBlock())
5259 {
5260 node->getFalseBlock()->traverse(this);
5261 mBuilder.writeBranchConditionalBlockEnd();
5262 }
5263
5264 // Pop from the conditional stack when done.
5265 mBuilder.endConditional();
5266
5267 // Don't traverse the children, that's done already.
5268 return false;
5269 }
5270
// Generates structured SPIR-V control flow (OpSelectionMerge + OpSwitch) for a GLSL switch,
// including fallthrough between case blocks.
bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
{
    // Take the following switch:
    //
    //     switch (c)
    //     {
    //     case A:
    //         ABlock;
    //         break;
    //     case B:
    //     default:
    //         BBlock;
    //         break;
    //     case C:
    //         CBlock;
    //         // fallthrough
    //     case D:
    //         DBlock;
    //     }
    //
    // In SPIR-V, this is implemented similarly to the following pseudo-code:
    //
    //     switch c:
    //         A        -> jump %A
    //         B        -> jump %B
    //         C        -> jump %C
    //         D        -> jump %D
    //         default  -> jump %B
    //
    //     %A:
    //         ABlock
    //         jump %merge
    //
    //     %B:
    //         BBlock
    //         jump %merge
    //
    //     %C:
    //         CBlock
    //         jump %D
    //
    //     %D:
    //         DBlock
    //         jump %merge
    //
    // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
    // block either terminates with a jump to the merge block or the next block as fallthrough.
    //
    //     // pre-switch block
    //     OpSelectionMerge %merge None
    //     OpSwitch %cond %C A %A B %B C %C D %D
    //
    //     %A = OpLabel
    //         ABlock
    //         OpBranch %merge
    //
    //     %B = OpLabel
    //         BBlock
    //         OpBranch %merge
    //
    //     %C = OpLabel
    //         CBlock
    //         OpBranch %D
    //
    //     %D = OpLabel
    //         DBlock
    //         OpBranch %merge

    if (visit == PreVisit)
    {
        // Artificially add `if (true)` around switches as a driver bug workaround
        if (mCompileOptions.wrapSwitchInIfTrue)
        {
            const spirv::IdRef conditionValue = mBuilder.getBoolConstant(true);
            mBuilder.startConditional(2, false, false);
            const SpirvConditional *conditional = mBuilder.getCurrentConditional();
            const spirv::IdRef trueBlock = conditional->blockIds[0];
            const spirv::IdRef mergeBlock = conditional->blockIds[1];
            mBuilder.writeBranchConditional(conditionValue, trueBlock, mergeBlock, mergeBlock);
        }

        // Don't add an entry to the stack. The condition will create one, which we won't pop.
        return true;
    }

    // If the condition was just visited, evaluate it and create the switch instruction.
    if (visit == InVisit)
    {
        // InVisit is only expected right after the condition (child 0).
        ASSERT(getLastTraversedChildIndex(visit) == 0);

        const spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);

        // First, need to find out how many blocks are there in the switch.
        const TIntermSequence &statements = *node->getStatementList()->getSequence();
        // |lastWasCase| starts true so the first non-label statement advances |blockIndex| only
        // after any leading labels have recorded index 0.
        bool lastWasCase = true;
        size_t blockIndex = 0;

        // SIZE_MAX is a sentinel meaning "no default label seen".
        size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
        TVector<uint32_t> caseValues;
        TVector<size_t> caseBlockIndices;

        for (TIntermNode *statement : statements)
        {
            TIntermCase *caseLabel = statement->getAsCaseNode();
            const bool isCaseLabel = caseLabel != nullptr;

            if (isCaseLabel)
            {
                // For every case label, remember its block index. This is used later to generate
                // the OpSwitch instruction.
                if (caseLabel->hasCondition())
                {
                    // All switch conditions are literals.
                    TIntermConstantUnion *condition =
                        caseLabel->getCondition()->getAsConstantUnion();
                    ASSERT(condition != nullptr);

                    // OpSwitch takes unsigned literal operands; convert the case value to uint.
                    TConstantUnion caseValue;
                    if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
                    {
                        caseValue.setUConst(
                            condition->getConstantValue()->getYuvCscStandardEXTConst());
                    }
                    else
                    {
                        bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
                        ASSERT(valid);
                    }

                    caseValues.push_back(caseValue.getUConst());
                    caseBlockIndices.push_back(blockIndex);
                }
                else
                {
                    // Remember the block index of the default case.
                    defaultBlockIndex = blockIndex;
                }
                lastWasCase = true;
            }
            else if (lastWasCase)
            {
                // Every time a non-case node is visited and the previous statement was a case node,
                // it's a new block.
                ++blockIndex;
                lastWasCase = false;
            }
        }

        // Block count is the number of blocks based on cases + 1 for the merge block.
        const size_t blockCount = blockIndex + 1;
        mBuilder.startConditional(blockCount, false, true);

        // Generate the switch instructions.
        const SpirvConditional *conditional = mBuilder.getCurrentConditional();

        // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction. If
        // the switch ends in a number of cases with no statements following them, they will
        // naturally jump to the merge block!
        spirv::PairLiteralIntegerIdRefList switchTargets;

        for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
        {
            uint32_t value = caseValues[caseIndex];
            size_t caseBlockIndex = caseBlockIndices[caseIndex];

            switchTargets.push_back(
                {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
        }

        // If there is no default label, |defaultBlockIndex| is still the SIZE_MAX sentinel and
        // the default target falls through to the merge block.  A real default's block index
        // never exceeds the number of case labels, so comparing against caseValues.size()
        // distinguishes the two.
        const spirv::IdRef mergeBlock = conditional->blockIds.back();
        const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
                                              ? conditional->blockIds[defaultBlockIndex]
                                              : mergeBlock;

        mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
        return true;
    }

    // PostVisit: terminate the last block if not already and end the conditional.
    mBuilder.writeSwitchCaseBlockEnd();
    mBuilder.endConditional();

    // Close the artificial `if (true)` wrapper opened in PreVisit, if any.
    if (mCompileOptions.wrapSwitchInIfTrue)
    {
        mBuilder.writeBranchConditionalBlockEnd();
        mBuilder.endConditional();
    }

    return true;
}
5462
visitCase(Visit visit,TIntermCase * node)5463 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5464 {
5465 ASSERT(visit == PreVisit);
5466
5467 mNodeData.emplace_back();
5468
5469 TIntermBlock *parent = getParentNode()->getAsBlock();
5470 const size_t childIndex = getParentChildIndex(PreVisit);
5471
5472 ASSERT(parent);
5473 const TIntermSequence &parentStatements = *parent->getSequence();
5474
5475 // Check the previous statement. If it was not a |case|, then a new block is being started so
5476 // handle fallthrough:
5477 //
5478 // ...
5479 // statement;
5480 // case X: <--- end the previous block here
5481 // case Y:
5482 //
5483 //
5484 if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5485 {
5486 mBuilder.writeSwitchCaseBlockEnd();
5487 }
5488
5489 // Don't traverse the condition, as it was processed in visitSwitch.
5490 return false;
5491 }
5492
visitBlock(Visit visit,TIntermBlock * node)5493 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5494 {
5495 // If global block, nothing to do.
5496 if (getCurrentTraversalDepth() == 0)
5497 {
5498 return true;
5499 }
5500
5501 // Any construct that needs code blocks must have already handled creating the necessary blocks
5502 // and setting the right one "current". If there's a block opened in GLSL for scoping reasons,
5503 // it's ignored here as there are no scopes within a function in SPIR-V.
5504 if (visit == PreVisit)
5505 {
5506 return node->getChildCount() > 0;
5507 }
5508
5509 // Any node that needed to generate code has already done so, just clean up its data. If
5510 // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5511 // side effects of n already having generated code).
5512 //
5513 // Blocks inside blocks like:
5514 //
5515 // {
5516 // statement;
5517 // {
5518 // statement2;
5519 // }
5520 // }
5521 //
5522 // don't generate nodes.
5523 const size_t childIndex = getLastTraversedChildIndex(visit);
5524 const TIntermSequence &statements = *node->getSequence();
5525
5526 if (statements[childIndex]->getAsBlock() == nullptr)
5527 {
5528 mNodeData.pop_back();
5529 }
5530
5531 return true;
5532 }
5533
// Emits OpFunction/OpFunctionParameter on InVisit (right after the prototype child), and on
// PostVisit closes the function: adds an implicit return if the last block isn't terminated,
// assembles the blocks and writes OpFunctionEnd.  Uses the ids reserved by
// visitFunctionPrototype via |mFunctionIdMap|.
bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
    if (visit == PreVisit)
    {
        return true;
    }

    const TFunction *function = node->getFunction();

    // The prototype must have been visited first and registered the function's ids.
    ASSERT(mFunctionIdMap.count(function) > 0);
    const FunctionIds &ids = mFunctionIdMap[function];

    // After the prototype is visited, generate the initial code for the function.
    if (visit == InVisit)
    {
        // Declare the function.
        spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
                             spv::FunctionControlMaskNone, ids.functionTypeId);

        // Declare one OpFunctionParameter per parameter, using the pointer/const type ids
        // computed by visitFunctionPrototype.
        for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
        {
            const TVariable *paramVariable = function->getParam(paramIndex);

            const spirv::IdRef paramId =
                mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
            spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
                                          ids.parameterTypeIds[paramIndex], paramId);

            // Remember the id of the variable for future look up.
            ASSERT(mSymbolIdMap.count(paramVariable) == 0);
            mSymbolIdMap[paramVariable] = paramId;

            mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
        }

        mBuilder.startNewFunction(ids.functionId, function);

        // For main(), add a non-semantic instruction at the beginning for any initialization code
        // the transformer may want to add.
        if (ids.functionId == vk::spirv::kIdEntryPoint &&
            mCompiler->getShaderType() != GL_COMPUTE_SHADER)
        {
            ASSERT(function->isMain());
            mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
        }

        mCurrentFunctionId = ids.functionId;

        return true;
    }

    // PostVisit: if no explicit return was specified, add one automatically here.
    if (!mBuilder.isCurrentFunctionBlockTerminated())
    {
        if (function->getReturnType().getBasicType() == EbtVoid)
        {
            // A couple of functions with reserved ids need a non-semantic marker before the
            // implicit return.
            switch (ids.functionId)
            {
                case vk::spirv::kIdEntryPoint:
                    // For main(), add a non-semantic instruction at the end of the shader.
                    markVertexOutputOnShaderEnd();
                    break;
                case vk::spirv::kIdXfbEmulationCaptureFunction:
                    // For the transform feedback emulation capture function, add a non-semantic
                    // instruction before return for the transformer to fill in as necessary.
                    mBuilder.writeNonSemanticInstruction(
                        vk::spirv::kNonSemanticTransformFeedbackEmulation);
                    break;
            }

            spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
        }
        else
        {
            // GLSL allows functions expecting a return value to miss a return. In that case,
            // return a null constant.
            const TType &returnType = function->getReturnType();
            spirv::IdRef nullConstant;
            // Scalar numeric returns get a typed zero constant; everything else (vectors,
            // matrices, structs, arrays) uses OpConstantNull of the return type.
            if (returnType.isScalar() && !returnType.isArray())
            {
                switch (function->getReturnType().getBasicType())
                {
                    case EbtFloat:
                        nullConstant = mBuilder.getFloatConstant(0);
                        break;
                    case EbtUInt:
                        nullConstant = mBuilder.getUintConstant(0);
                        break;
                    case EbtInt:
                        nullConstant = mBuilder.getIntConstant(0);
                        break;
                    default:
                        break;
                }
            }
            if (!nullConstant.valid())
            {
                nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
            }
            spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
        }
        mBuilder.terminateCurrentFunctionBlock();
    }

    mBuilder.assembleSpirvFunctionBlocks();

    // End the function
    spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());

    mCurrentFunctionId = {};

    return true;
}
5647
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5648 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5649 TIntermGlobalQualifierDeclaration *node)
5650 {
5651 if (node->isPrecise())
5652 {
5653 // Nothing to do for |precise|.
5654 return false;
5655 }
5656
5657 // Global qualifier declarations apply to variables that are already declared. Invariant simply
5658 // adds a decoration to the variable declaration, which can be done right away. Note that
5659 // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5660 // which are applied to the members directly by DeclarePerVertexBlocks.
5661 ASSERT(node->isInvariant());
5662
5663 const TVariable *variable = &node->getSymbol()->variable();
5664 ASSERT(mSymbolIdMap.count(variable) > 0);
5665
5666 const spirv::IdRef variableId = mSymbolIdMap[variable];
5667
5668 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5669
5670 return false;
5671 }
5672
visitFunctionPrototype(TIntermFunctionPrototype * node)5673 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5674 {
5675 const TFunction *function = node->getFunction();
5676
5677 // If the function was previously forward declared, skip this.
5678 if (mFunctionIdMap.count(function) > 0)
5679 {
5680 return;
5681 }
5682
5683 FunctionIds ids;
5684
5685 // Declare the function type
5686 ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5687
5688 spirv::IdRefList paramTypeIds;
5689 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5690 {
5691 const TType ¶mType = function->getParam(paramIndex)->getType();
5692
5693 spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5694
5695 // const function parameters are intermediate values, while the rest are "variables"
5696 // with the Function storage class.
5697 if (paramType.getQualifier() != EvqParamConst)
5698 {
5699 const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5700 ? spv::StorageClassUniformConstant
5701 : spv::StorageClassFunction;
5702 paramId = mBuilder.getTypePointerId(paramId, storageClass);
5703 }
5704
5705 ids.parameterTypeIds.push_back(paramId);
5706 }
5707
5708 ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5709
5710 // Allocate an id for the function up-front.
5711 //
5712 // Apply decorations to the return value of the function by applying them to the OpFunction
5713 // instruction.
5714 //
5715 // Note that some functions have predefined ids.
5716 if (function->isMain())
5717 {
5718 ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5719 }
5720 else
5721 {
5722 ids.functionId = mBuilder.getReservedOrNewId(
5723 function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5724 }
5725
5726 // Remember the id of the function for future look up.
5727 mFunctionIdMap[function] = ids;
5728 }
5729
// Handles aggregate nodes: constructors, function calls, built-in operators, and the
// no-result built-ins (barriers, geometry-shader emits, fragment shader interlock).  Parameters
// accumulate on |mNodeData| during traversal and are consumed here on PostVisit.
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    // Constants are expected to be folded. However, large constructors (such as arrays) are not
    // folded and are handled here.
    ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());

    if (visit == PreVisit)
    {
        mNodeData.emplace_back();
        return true;
    }

    // Keep the parameters on the stack. If a function call contains out or inout parameters, we
    // need to know the access chains for the eventual write back to them.
    if (visit == InVisit)
    {
        return true;
    }

    // PostVisit: expect to have accumulated as many parameters as the node requires.
    ASSERT(mNodeData.size() > node->getChildCount());

    const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
    // Remains invalid for ops that produce no value (barriers, emits, interlock).
    spirv::IdRef result;

    switch (node->getOp())
    {
        case EOpConstruct:
            // Construct a value out of the accumulated parameters.
            result = createConstructor(node, resultTypeId);
            break;
        case EOpCallFunctionInAST:
            // Create a call to the function.
            result = createFunctionCall(node, resultTypeId);
            break;

        // For barrier functions the scope is device, or with the Vulkan memory model, the queue
        // family. We don't use the Vulkan memory model.
        case EOpBarrier:
            spirv::WriteControlBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::ScopeWorkgroup),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpBarrierTCS:
            // Note: The memory scope and semantics are different with the Vulkan memory model,
            // which is not supported.
            spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
                                       mBuilder.getUintConstant(spv::ScopeWorkgroup),
                                       mBuilder.getUintConstant(spv::ScopeInvocation),
                                       mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
            break;
        case EOpMemoryBarrier:
        case EOpGroupMemoryBarrier:
        {
            // Same semantics mask, different scope: device-wide for memoryBarrier(),
            // workgroup-wide for groupMemoryBarrier().
            const spv::Scope scope =
                node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        }
        case EOpMemoryBarrierBuffer:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierImage:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierShared:
            spirv::WriteMemoryBarrier(
                mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
                mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
                                         spv::MemorySemanticsAcquireReleaseMask));
            break;
        case EOpMemoryBarrierAtomicCounter:
            // Atomic counters are emulated.
            UNREACHABLE();
            break;

        case EOpEmitVertex:
            if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
            {
                // Add a non-semantic instruction before EmitVertex.
                markVertexOutputOnEmitVertex();
            }
            spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
            break;
        case EOpEndPrimitive:
            spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpBeginInvocationInterlockARB:
            // Set up a "pixel_interlock_ordered" execution mode, as that is the default
            // interlocked execution mode in GLSL, and we don't currently expose an option to change
            // that.
            mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
            mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
            mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        case EOpEndInvocationInterlockARB:
            // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
            spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
            break;

        default:
            result = visitOperator(node, resultTypeId);
            break;
    }

    // Pop the parameters.
    mNodeData.resize(mNodeData.size() - node->getChildCount());

    // Keep the result as rvalue.
    nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);

    return false;
}
5861
// Emits the SPIR-V for a single-variable declaration: spec constants and constants are handled
// specially, otherwise an OpVariable is declared (possibly with an initializer or a follow-up
// OpStore) and all applicable decorations (memory qualifiers, Block/BufferBlock, Patch,
// Invariant, interface decorations) are written.
bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
{
    const TIntermSequence &sequence = *node->getSequence();

    // Enforced by ValidateASTOptions::validateMultiDeclarations.
    ASSERT(sequence.size() == 1);

    // Declare specialization constants especially; they don't require processing the left and right
    // nodes, and they are like constant declarations with special instructions and decorations.
    const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
    if (qualifier == EvqSpecConst)
    {
        declareSpecConst(node);
        return false;
    }
    // Similarly, constant declarations are turned into actual constants.
    if (qualifier == EvqConst)
    {
        declareConst(node);
        return false;
    }

    // Skip redeclaration of builtins. They will correctly declare as built-in on first use.
    if (mInGlobalScope &&
        (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
    {
        return false;
    }

    // Local declarations push node data to hold the (possible) initializer expression's result.
    if (!mInGlobalScope && visit == PreVisit)
    {
        mNodeData.emplace_back();
    }

    mIsSymbolBeingDeclared = visit == PreVisit;

    if (visit != PostVisit)
    {
        return true;
    }

    TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
    spirv::IdRef initializerId;
    bool initializeWithDeclaration = false;
    bool needsQuantizeTo16 = false;

    // Handle declarations with initializer.  |symbol| is null exactly when the declaration is an
    // EOpInitialize assignment rather than a bare symbol.
    if (symbol == nullptr)
    {
        TIntermBinary *assign = sequence.front()->getAsBinaryNode();
        ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);

        symbol = assign->getLeft()->getAsSymbolNode();
        ASSERT(symbol != nullptr);

        // In SPIR-V, it's only possible to initialize a variable together with its declaration if
        // the initializer is a constant or a global variable. We ignore the global variable case
        // to avoid tracking whether the variable has been modified since the beginning of the
        // function. Since variable declarations are always placed at the beginning of the function
        // in SPIR-V, it would be wrong for example to initialize |var| below with the global
        // variable at declaration time:
        //
        //     vec4 global = A;
        //     void f()
        //     {
        //         global = B;
        //         {
        //             vec4 var = global;
        //         }
        //     }
        //
        // So the initializer is only used when declaring a variable when it's a constant
        // expression. Note that if the variable being declared is itself global (and the
        // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
        // makes sure their initialization is deferred to the beginning of main.
        //
        // Additionally, if the variable is being defined inside a loop, the initializer is not used
        // as that would prevent it from being reinitialized in the next iteration of the loop.

        TIntermTyped *initializer = assign->getRight();
        initializeWithDeclaration =
            !mBuilder.isInLoop() &&
            (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());

        if (initializeWithDeclaration)
        {
            // If a constant, take the Id directly.
            initializerId = mNodeData.back().baseId;
        }
        else
        {
            // Otherwise generate code to load from right hand side expression.
            initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);

            // Workaround for issuetracker.google.com/274859104
            // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
            // and chooses to not cast mediump float to 16 bit. This causes deqp test
            // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
            // The reason is that GLSL shader code expects below condition to be true:
            //     mediump float a == mediump float b;
            // However, the condition is false after translating to SpirV
            // due to one of them is 32 bit, and the other is 16 bit.
            // To resolve the deqp test failure, we will add an OpQuantizeToF16
            // SpirV instruction to explicitly cast mediump float scalar or mediump float
            // vector to 16 bit, if the right-hand-side is a highp float.
            if (mCompileOptions.castMediumpFloatTo16Bit)
            {
                const TType leftType = assign->getLeft()->getType();
                const TType rightType = assign->getRight()->getType();
                const TPrecision leftPrecision = leftType.getPrecision();
                const TPrecision rightPrecision = rightType.getPrecision();
                const bool isLeftScalarFloat = leftType.isScalarFloat();
                const bool isLeftVectorFloat = leftType.isVector() && !leftType.isVectorArray() &&
                                               leftType.getBasicType() == EbtFloat;

                if (leftPrecision == TPrecision::EbpMedium &&
                    rightPrecision == TPrecision::EbpHigh &&
                    (isLeftScalarFloat || isLeftVectorFloat))
                {
                    needsQuantizeTo16 = true;
                }
            }
        }

        // Clean up the initializer data.
        mNodeData.pop_back();
    }

    const TType &type = symbol->getType();
    const TVariable *variable = &symbol->variable();

    // If this is just a struct declaration (and not a variable declaration), don't declare the
    // struct up-front and let it be lazily defined. If the struct is only used inside an interface
    // block for example, this avoids it being doubly defined (once with the unspecified block
    // storage and once with interface block's).
    if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
    {
        return false;
    }

    const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;

    spv::StorageClass storageClass =
        GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());

    SpirvDecorations decorations = mBuilder.getDecorations(type);
    if (mBuilder.isInvariantOutput(type))
    {
        // Apply the Invariant decoration to output variables if specified or if globally enabled.
        decorations.push_back(spv::DecorationInvariant);
    }
    // Apply the declared memory qualifiers.
    TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
    if (memoryQualifier.coherent)
    {
        decorations.push_back(spv::DecorationCoherent);
    }
    if (memoryQualifier.volatileQualifier)
    {
        decorations.push_back(spv::DecorationVolatile);
    }
    if (memoryQualifier.restrictQualifier)
    {
        decorations.push_back(spv::DecorationRestrict);
    }
    if (memoryQualifier.readonly)
    {
        decorations.push_back(spv::DecorationNonWritable);
    }
    if (memoryQualifier.writeonly)
    {
        decorations.push_back(spv::DecorationNonReadable);
    }

    // Declare the OpVariable, with the initializer attached only when it's usable at declaration
    // time (constant initializer, not inside a loop).
    const spirv::IdRef variableId = mBuilder.declareVariable(
        typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
        mBuilder.getName(variable).data(), &variable->uniqueId());

    if (!initializeWithDeclaration && initializerId.valid())
    {
        // If not initializing at the same time as the declaration, issue a store
        if (needsQuantizeTo16)
        {
            // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
            // issuing an OpStore instruction.
            const spirv::IdRef quantizeToF16Result =
                mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
            spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
                                      quantizeToF16Result, initializerId);
            initializerId = quantizeToF16Result;
        }
        spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
                          nullptr);
    }

    const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
    const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;

    // Add decorations, which apply to the element type of arrays, if array.
    spirv::IdRef nonArrayTypeId = typeId;
    if (type.isArray() && (isShaderInOut || isInterfaceBlock))
    {
        SpirvType elementType = mBuilder.getSpirvType(type, {});
        elementType.arraySizes = {};
        nonArrayTypeId = mBuilder.getSpirvTypeData(elementType, nullptr).id;
    }

    if (isShaderInOut)
    {
        if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
        {
            // For gl_PerVertex in particular, write the necessary BuiltIn decorations
            if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
            {
                mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
            }

            // I/O blocks are decorated with Block
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
                                 spv::DecorationBlock, {});
        }
        else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
        {
            // Tessellation shaders can have their input or output qualified with |patch|. For I/O
            // blocks, the members are decorated instead.
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
                                 {});
        }
    }
    else if (isInterfaceBlock)
    {
        // For uniform and buffer variables, with SPIR-V 1.3 add Block and BufferBlock decorations
        // respectively. With SPIR-V 1.4, always add Block.
        const spv::Decoration decoration =
            mCompileOptions.emitSPIRV14 || type.getQualifier() == EvqUniform
                ? spv::DecorationBlock
                : spv::DecorationBufferBlock;
        spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});

        if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
            mCompileOptions.aliasedUnlessRestrict)
        {
            // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
            // memory qualifier is aliased
            // issuetracker.google.com/266235549
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
                                 {});
        }
    }
    else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
    {
        // If GLSL does not specify the image has restrict memory qualifier, assume the memory
        // qualifier is aliased
        // issuetracker.google.com/266235549
        if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
        {
            spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
                                 {});
        }
    }

    // Write DescriptorSet, Binding, Location etc decorations if necessary.
    mBuilder.writeInterfaceVariableDecorations(type, variableId);

    // Remember the id of the variable for future look up. For interface blocks, also remember the
    // id of the interface block.
    ASSERT(mSymbolIdMap.count(variable) == 0);
    mSymbolIdMap[variable] = variableId;

    if (type.isInterfaceBlock())
    {
        ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
        mSymbolIdMap[type.getInterfaceBlock()] = variableId;
    }

    return false;
}
6139
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6140 void GetLoopBlocks(const SpirvConditional *conditional,
6141 TLoopType loopType,
6142 bool hasCondition,
6143 spirv::IdRef *headerBlock,
6144 spirv::IdRef *condBlock,
6145 spirv::IdRef *bodyBlock,
6146 spirv::IdRef *continueBlock,
6147 spirv::IdRef *mergeBlock)
6148 {
6149 // The order of the blocks is for |for| and |while|:
6150 //
6151 // %header %cond [optional] %body %continue %merge
6152 //
6153 // and for |do-while|:
6154 //
6155 // %header %body %cond %merge
6156 //
6157 // Note that the |break| target is always the last block and the |continue| target is the one
6158 // before last.
6159 //
6160 // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6161 // If %cond is not present, all jumps are made to %body instead.
6162
6163 size_t nextBlock = 0;
6164 *headerBlock = conditional->blockIds[nextBlock++];
6165 // %cond, if any is after header except for |do-while|.
6166 if (loopType != ELoopDoWhile && hasCondition)
6167 {
6168 *condBlock = conditional->blockIds[nextBlock++];
6169 }
6170 *bodyBlock = conditional->blockIds[nextBlock++];
6171 // After the block is either %cond or %continue based on |do-while| or not.
6172 if (loopType != ELoopDoWhile)
6173 {
6174 *continueBlock = conditional->blockIds[nextBlock++];
6175 }
6176 else
6177 {
6178 *condBlock = conditional->blockIds[nextBlock++];
6179 }
6180 *mergeBlock = conditional->blockIds[nextBlock++];
6181
6182 ASSERT(nextBlock == conditional->blockIds.size());
6183
6184 if (!continueBlock->valid())
6185 {
6186 ASSERT(condBlock->valid());
6187 *continueBlock = *condBlock;
6188 }
6189 if (!condBlock->valid())
6190 {
6191 *condBlock = *bodyBlock;
6192 }
6193 }
6194
bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
{
    // There are three kinds of loops, and they translate as such:
    //
    // for (init; cond; expr) body;
    //
    //               // pre-loop block
    //               init
    //               OpBranch %header
    //
    //  %header:     OpLabel
    //               OpLoopMerge %merge %continue None
    //               OpBranch %cond
    //
    //               // Note: if cond doesn't exist, this section is not generated.  The above
    //               // OpBranch would jump directly to %body.
    //  %cond:       OpLabel
    //     %v =      cond
    //               OpBranchConditional %v %body %merge None
    //
    //  %body:       OpLabel
    //               body
    //               OpBranch %continue
    //
    //  %continue:   OpLabel
    //               expr
    //               OpBranch %header
    //
    //               // post-loop block
    //  %merge:      OpLabel
    //
    //
    // while (cond) body;
    //
    //               // pre-for block
    //               OpBranch %header
    //
    //  %header:     OpLabel
    //               OpLoopMerge %merge %continue None
    //               OpBranch %cond
    //
    //  %cond:       OpLabel
    //     %v =      cond
    //               OpBranchConditional %v %body %merge None
    //
    //  %body:       OpLabel
    //               body
    //               OpBranch %continue
    //
    //  %continue:   OpLabel
    //               OpBranch %header
    //
    //               // post-loop block
    //  %merge:      OpLabel
    //
    //
    // do body; while (cond);
    //
    //               // pre-for block
    //               OpBranch %header
    //
    //  %header:     OpLabel
    //               OpLoopMerge %merge %cond None
    //               OpBranch %body
    //
    //  %body:       OpLabel
    //               body
    //               OpBranch %cond
    //
    //  %cond:       OpLabel
    //     %v =      cond
    //               OpBranchConditional %v %header %merge None
    //
    //               // post-loop block
    //  %merge:      OpLabel
    //

    // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
    // this function enforces traversal in the right order.
    ASSERT(visit == PreVisit);
    // Placeholder node data for the loop node itself; children's data is pushed/popped as each
    // child is traversed manually below.
    mNodeData.emplace_back();

    const TLoopType loopType = node->getType();

    // The init statement of a for loop is placed in the previous block, so continue generating code
    // as-is until that statement is done.
    if (node->getInit())
    {
        // Only |for| loops can have an init statement.
        ASSERT(loopType == ELoopFor);
        node->getInit()->traverse(this);
        mNodeData.pop_back();
    }

    const bool hasCondition = node->getCondition() != nullptr;

    // Once the init node is visited, if any, we need to set up the loop.
    //
    // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
    // need %header, %body and %merge.  If condition is present, an additional %cond block is
    // needed in each case.
    const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
    // The conditional is marked as both a breakable and continuable scope.
    mBuilder.startConditional(blockCount, true, true);

    // Generate the %header block.
    const SpirvConditional *conditional = mBuilder.getCurrentConditional();

    spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
    GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
                  &continueBlock, &mergeBlock);

    // |do-while| branches from %header straight to %body; the others to %cond (which falls back
    // to %body if there is no condition).
    mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
                             mergeBlock);

    // %cond, if any is after header except for |do-while|.
    if (loopType != ELoopDoWhile && hasCondition)
    {
        node->getCondition()->traverse(this);

        // Generate the branch at the end of the %cond block.
        const spirv::IdRef conditionValue =
            accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
        mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);

        mNodeData.pop_back();
    }

    // Next comes %body.
    {
        node->getBody()->traverse(this);

        // Generate the branch at the end of the %body block.
        mBuilder.writeLoopBodyEnd(continueBlock);
    }

    switch (loopType)
    {
        case ELoopFor:
            // For |for| loops, the expression is placed after the body and acts as the continue
            // block.
            if (node->getExpression())
            {
                node->getExpression()->traverse(this);
                mNodeData.pop_back();
            }

            // Generate the branch at the end of the %continue block.
            mBuilder.writeLoopContinueEnd(headerBlock);
            break;

        case ELoopWhile:
            // |for| loops have the expression in the continue block and |do-while| loops have their
            // condition block act as the loop's continue block.  |while| loops need a branch-only
            // continue loop, which is generated here.
            mBuilder.writeLoopContinueEnd(headerBlock);
            break;

        case ELoopDoWhile:
            // For |do-while|, %cond comes last.
            ASSERT(hasCondition);
            node->getCondition()->traverse(this);

            // Generate the branch at the end of the %cond block; on true, it loops back to
            // %header (which doubles as the continue target for |do-while|).
            const spirv::IdRef conditionValue =
                accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
            mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);

            mNodeData.pop_back();
            break;
    }

    // Pop from the conditional stack when done.
    mBuilder.endConditional();

    // Don't traverse the children, that's done already.
    return false;
}
6371
visitBranch(Visit visit,TIntermBranch * node)6372 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6373 {
6374 if (visit == PreVisit)
6375 {
6376 mNodeData.emplace_back();
6377 return true;
6378 }
6379
6380 // There is only ever one child at most.
6381 ASSERT(visit != InVisit);
6382
6383 switch (node->getFlowOp())
6384 {
6385 case EOpKill:
6386 spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6387 mBuilder.terminateCurrentFunctionBlock();
6388 break;
6389 case EOpBreak:
6390 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6391 mBuilder.getBreakTargetId());
6392 mBuilder.terminateCurrentFunctionBlock();
6393 break;
6394 case EOpContinue:
6395 spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6396 mBuilder.getContinueTargetId());
6397 mBuilder.terminateCurrentFunctionBlock();
6398 break;
6399 case EOpReturn:
6400 // Evaluate the expression if any, and return.
6401 if (node->getExpression() != nullptr)
6402 {
6403 ASSERT(mNodeData.size() >= 1);
6404
6405 const spirv::IdRef expressionValue =
6406 accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6407 mNodeData.pop_back();
6408
6409 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6410 mBuilder.terminateCurrentFunctionBlock();
6411 }
6412 else
6413 {
6414 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6415 {
6416 // For main(), add a non-semantic instruction at the end of the shader.
6417 markVertexOutputOnShaderEnd();
6418 }
6419 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6420 mBuilder.terminateCurrentFunctionBlock();
6421 }
6422 break;
6423 default:
6424 UNREACHABLE();
6425 }
6426
6427 return true;
6428 }
6429
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6430 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6431 {
6432 // No preprocessor directives expected at this point.
6433 UNREACHABLE();
6434 }
6435
markVertexOutputOnShaderEnd()6436 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6437 {
6438 // Output happens in vertex and fragment stages at return from main.
6439 // In geometry shaders, it's done at EmitVertex.
6440 switch (mCompiler->getShaderType())
6441 {
6442 case GL_FRAGMENT_SHADER:
6443 case GL_VERTEX_SHADER:
6444 case GL_TESS_CONTROL_SHADER_EXT:
6445 case GL_TESS_EVALUATION_SHADER_EXT:
6446 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6447 break;
6448 default:
6449 break;
6450 }
6451 }
6452
markVertexOutputOnEmitVertex()6453 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6454 {
6455 // Vertex output happens in the geometry stage at EmitVertex.
6456 if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6457 {
6458 mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6459 }
6460 }
6461
getSpirv()6462 spirv::Blob OutputSPIRVTraverser::getSpirv()
6463 {
6464 spirv::Blob result = mBuilder.getSpirv();
6465
6466 // Validate that correct SPIR-V was generated
6467 ASSERT(spirv::Validate(result));
6468
6469 #if ANGLE_DEBUG_SPIRV_GENERATION
6470 // Disassemble and log the generated SPIR-V for debugging.
6471 spvtools::SpirvTools spirvTools(mCompileOptions.emitSPIRV14 ? SPV_ENV_VULKAN_1_1_SPIRV_1_4
6472 : SPV_ENV_VULKAN_1_1);
6473 std::string readableSpirv;
6474 spirvTools.Disassemble(result, &readableSpirv,
6475 SPV_BINARY_TO_TEXT_OPTION_COMMENT | SPV_BINARY_TO_TEXT_OPTION_INDENT |
6476 SPV_BINARY_TO_TEXT_OPTION_NESTED_INDENT);
6477 fprintf(stderr, "%s\n", readableSpirv.c_str());
6478 #endif // ANGLE_DEBUG_SPIRV_GENERATION
6479
6480 return result;
6481 }
6482 } // anonymous namespace
6483
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6484 bool OutputSPIRV(TCompiler *compiler,
6485 TIntermBlock *root,
6486 const ShCompileOptions &compileOptions,
6487 const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6488 uint32_t firstUnusedSpirvId)
6489 {
6490 // Find the list of nodes that require NoContraction (as a result of |precise|).
6491 if (compiler->hasAnyPreciseType())
6492 {
6493 FindPreciseNodes(compiler, root);
6494 }
6495
6496 // Traverse the tree and generate SPIR-V instructions
6497 OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6498 firstUnusedSpirvId);
6499 root->traverse(&traverser);
6500
6501 // Generate the final SPIR-V and store in the sink
6502 spirv::Blob spirvBlob = traverser.getSpirv();
6503 compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6504
6505 return true;
6506 }
6507 } // namespace sh
6508