xref: /aosp_15_r20/external/angle/src/compiler/translator/spirv/OutputSPIRV.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // OutputSPIRV: Generate SPIR-V from the AST.
7 //
8 
9 #include "compiler/translator/spirv/OutputSPIRV.h"
10 
11 #include "angle_gl.h"
12 #include "common/debug.h"
13 #include "common/mathutil.h"
14 #include "common/spirv/spirv_instruction_builder_autogen.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/StaticType.h"
17 #include "compiler/translator/spirv/BuildSPIRV.h"
18 #include "compiler/translator/tree_util/FindPreciseNodes.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20 
21 #include <cfloat>
22 
23 // Extended instructions
24 namespace spv
25 {
26 #include <spirv/unified1/GLSL.std.450.h>
27 }
28 
29 // SPIR-V tools include for disassembly
30 #include <spirv-tools/libspirv.hpp>
31 
32 // Enable this for debug logging of pre-transform SPIR-V:
33 #if !defined(ANGLE_DEBUG_SPIRV_GENERATION)
34 #    define ANGLE_DEBUG_SPIRV_GENERATION 0
35 #endif  // !defined(ANGLE_DEBUG_SPIRV_GENERATION)
36 
37 namespace sh
38 {
39 namespace
40 {
41 // A struct to hold either SPIR-V ids or literal constants.   If id is not valid, a literal is
42 // assumed.
43 struct SpirvIdOrLiteral
44 {
45     SpirvIdOrLiteral() = default;
SpirvIdOrLiteralsh::__anonce4844100111::SpirvIdOrLiteral46     SpirvIdOrLiteral(const spirv::IdRef idIn) : id(idIn) {}
SpirvIdOrLiteralsh::__anonce4844100111::SpirvIdOrLiteral47     SpirvIdOrLiteral(const spirv::LiteralInteger literalIn) : literal(literalIn) {}
48 
49     spirv::IdRef id;
50     spirv::LiteralInteger literal;
51 };
52 
53 // A data structure to facilitate generating array indexing, block field selection, swizzle and
54 // such.  Used in conjunction with NodeData which includes the access chain's baseId and idList.
55 //
56 // - rvalue[literal].field[literal] generates OpCompositeExtract
57 // - rvalue.x generates OpCompositeExtract
58 // - rvalue.xyz generates OpVectorShuffle
59 // - rvalue.xyz[i] generates OpVectorExtractDynamic (xyz[i] itself generates an
60 //   OpVectorExtractDynamic as well)
61 // - rvalue[i].field[j] generates a temp variable OpStore'ing rvalue and then generating an
62 //   OpAccessChain and OpLoad
63 //
64 // - lvalue[i].field[j].x generates OpAccessChain and OpStore
65 // - lvalue.xyz generates an OpLoad followed by OpVectorShuffle and OpStore
66 // - lvalue.xyz[i] generates OpAccessChain and OpStore (xyz[i] itself generates an
67 //   OpVectorExtractDynamic as well)
68 //
69 // storageClass == Max implies an rvalue.
70 //
71 struct AccessChain
72 {
73     // The storage class for lvalues.  If Max, it's an rvalue.
74     spv::StorageClass storageClass = spv::StorageClassMax;
75     // If the access chain ends in swizzle, the swizzle components are specified here.  Swizzles
76     // select multiple components so need special treatment when used as lvalue.
77     std::vector<uint32_t> swizzles;
78     // If a vector component is selected dynamically (i.e. indexed with a non-literal index),
79     // dynamicComponent will contain the id of the index.
80     spirv::IdRef dynamicComponent;
81 
82     // Type of base expression, before swizzle is applied, after swizzle is applied and after
83     // dynamic component is applied.
84     spirv::IdRef baseTypeId;
85     spirv::IdRef preSwizzleTypeId;
86     spirv::IdRef postSwizzleTypeId;
87     spirv::IdRef postDynamicComponentTypeId;
88 
89     // If the OpAccessChain is already generated (done by accessChainCollapse()), this caches the
90     // id.
91     spirv::IdRef accessChainId;
92 
93     // Whether all indices are literal.  Avoids looping through indices to determine this
94     // information.
95     bool areAllIndicesLiteral = true;
96     // The number of components in the vector, if vector and swizzle is used.  This is cached to
97     // avoid a type look up when handling swizzles.
98     uint8_t swizzledVectorComponentCount = 0;
99 
100     // SPIR-V type specialization due to the base type.  Used to correctly select the SPIR-V type
101     // id when visiting EOpIndex* binary nodes (i.e. reading from or writing to an access chain).
102     // This always corresponds to the specialization specific to the end result of the access chain,
103     // not the base or any intermediary types.  For example, a struct nested in a column-major
104     // interface block, with a parent block qualified as row-major would specify row-major here.
105     SpirvTypeSpec typeSpec;
106 };
107 
108 // As each node is traversed, it produces data.  When visiting back the parent, this data is used to
109 // complete the data of the parent.  For example, the children of a function call (i.e. the
110 // arguments) each produce a SPIR-V id corresponding to the result of their expression.  The
111 // function call node itself in PostVisit uses those ids to generate the function call instruction.
112 struct NodeData
113 {
114     // An id whose meaning depends on the node.  It could be a temporary id holding the result of an
115     // expression, a reference to a variable etc.
116     spirv::IdRef baseId;
117 
118     // List of relevant SPIR-V ids accumulated while traversing the children.  Meaning depends on
119     // the node, for example a list of parameters to be passed to a function, a set of ids used to
120     // construct an access chain etc.
121     std::vector<SpirvIdOrLiteral> idList;
122 
123     // For constructing access chains.
124     AccessChain accessChain;
125 };
126 
127 struct FunctionIds
128 {
129     // Id of the function type, return type and parameter types.
130     spirv::IdRef functionTypeId;
131     spirv::IdRef returnTypeId;
132     spirv::IdRefList parameterTypeIds;
133 
134     // Id of the function itself.
135     spirv::IdRef functionId;
136 };
137 
138 struct BuiltInResultStruct
139 {
140     // Some builtins require a struct result.  The struct always has two fields of a scalar or
141     // vector type.
142     TBasicType lsbType;
143     TBasicType msbType;
144     uint32_t lsbPrimarySize;
145     uint32_t msbPrimarySize;
146 };
147 
148 struct BuiltInResultStructHash
149 {
operator ()sh::__anonce4844100111::BuiltInResultStructHash150     size_t operator()(const BuiltInResultStruct &key) const
151     {
152         static_assert(sh::EbtLast < 256, "Basic type doesn't fit in uint8_t");
153         ASSERT(key.lsbPrimarySize > 0 && key.lsbPrimarySize <= 4);
154         ASSERT(key.msbPrimarySize > 0 && key.msbPrimarySize <= 4);
155 
156         const uint8_t properties[4] = {
157             static_cast<uint8_t>(key.lsbType),
158             static_cast<uint8_t>(key.msbType),
159             static_cast<uint8_t>(key.lsbPrimarySize),
160             static_cast<uint8_t>(key.msbPrimarySize),
161         };
162 
163         return angle::ComputeGenericHash(properties, sizeof(properties));
164     }
165 };
166 
operator ==(const BuiltInResultStruct & a,const BuiltInResultStruct & b)167 bool operator==(const BuiltInResultStruct &a, const BuiltInResultStruct &b)
168 {
169     return a.lsbType == b.lsbType && a.msbType == b.msbType &&
170            a.lsbPrimarySize == b.lsbPrimarySize && a.msbPrimarySize == b.msbPrimarySize;
171 }
172 
IsAccessChainRValue(const AccessChain & accessChain)173 bool IsAccessChainRValue(const AccessChain &accessChain)
174 {
175     return accessChain.storageClass == spv::StorageClassMax;
176 }
177 
178 // A traverser that generates SPIR-V as it walks the AST.
179 class OutputSPIRVTraverser : public TIntermTraverser
180 {
181   public:
182     OutputSPIRVTraverser(TCompiler *compiler,
183                          const ShCompileOptions &compileOptions,
184                          const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
185                          uint32_t firstUnusedSpirvId);
186     ~OutputSPIRVTraverser() override;
187 
188     spirv::Blob getSpirv();
189 
190   protected:
191     void visitSymbol(TIntermSymbol *node) override;
192     void visitConstantUnion(TIntermConstantUnion *node) override;
193     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
194     bool visitBinary(Visit visit, TIntermBinary *node) override;
195     bool visitUnary(Visit visit, TIntermUnary *node) override;
196     bool visitTernary(Visit visit, TIntermTernary *node) override;
197     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
198     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
199     bool visitCase(Visit visit, TIntermCase *node) override;
200     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
201     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
202     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
203     bool visitBlock(Visit visit, TIntermBlock *node) override;
204     bool visitGlobalQualifierDeclaration(Visit visit,
205                                          TIntermGlobalQualifierDeclaration *node) override;
206     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
207     bool visitLoop(Visit visit, TIntermLoop *node) override;
208     bool visitBranch(Visit visit, TIntermBranch *node) override;
209     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
210 
211   private:
212     spirv::IdRef getSymbolIdAndStorageClass(const TSymbol *symbol,
213                                             const TType &type,
214                                             spv::StorageClass *storageClass);
215 
216     // Access chain handling.
217 
218     // Called before pushing indices to access chain to adjust |typeSpec| (which is then used to
219     // determine the typeId passed to |accessChainPush*|).
220     void accessChainOnPush(NodeData *data, const TType &parentType, size_t index);
221     void accessChainPush(NodeData *data, spirv::IdRef index, spirv::IdRef typeId) const;
222     void accessChainPushLiteral(NodeData *data,
223                                 spirv::LiteralInteger index,
224                                 spirv::IdRef typeId) const;
225     void accessChainPushSwizzle(NodeData *data,
226                                 const TVector<int> &swizzle,
227                                 spirv::IdRef typeId,
228                                 uint8_t componentCount) const;
229     void accessChainPushDynamicComponent(NodeData *data, spirv::IdRef index, spirv::IdRef typeId);
230     spirv::IdRef accessChainCollapse(NodeData *data);
231     spirv::IdRef accessChainLoad(NodeData *data,
232                                  const TType &valueType,
233                                  spirv::IdRef *resultTypeIdOut);
234     void accessChainStore(NodeData *data, spirv::IdRef value, const TType &valueType);
235 
236     // Access chain helpers.
237     void makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut);
238     void makeAccessChainLiteralList(NodeData *data, spirv::LiteralIntegerList *literalsOut);
239     spirv::IdRef getAccessChainTypeId(NodeData *data);
240 
241     // Node data handling.
242     void nodeDataInitLValue(NodeData *data,
243                             spirv::IdRef baseId,
244                             spirv::IdRef typeId,
245                             spv::StorageClass storageClass,
246                             const SpirvTypeSpec &typeSpec) const;
247     void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
248 
249     void declareConst(TIntermDeclaration *decl);
250     void declareSpecConst(TIntermDeclaration *decl);
251     spirv::IdRef createConstant(const TType &type,
252                                 TBasicType expectedBasicType,
253                                 const TConstantUnion *constUnion,
254                                 bool isConstantNullValue);
255     spirv::IdRef createComplexConstant(const TType &type,
256                                        spirv::IdRef typeId,
257                                        const spirv::IdRefList &parameters);
258     spirv::IdRef createConstructor(TIntermAggregate *node, spirv::IdRef typeId);
259     spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
260                                                 spirv::IdRef typeId,
261                                                 const spirv::IdRefList &parameters);
262     spirv::IdRef createConstructorScalarFromNonScalar(TIntermAggregate *node,
263                                                       spirv::IdRef typeId,
264                                                       const spirv::IdRefList &parameters);
265     spirv::IdRef createConstructorVectorFromScalar(const TType &parameterType,
266                                                    const TType &expectedType,
267                                                    spirv::IdRef typeId,
268                                                    const spirv::IdRefList &parameters);
269     spirv::IdRef createConstructorVectorFromMatrix(TIntermAggregate *node,
270                                                    spirv::IdRef typeId,
271                                                    const spirv::IdRefList &parameters);
272     spirv::IdRef createConstructorVectorFromMultiple(TIntermAggregate *node,
273                                                      spirv::IdRef typeId,
274                                                      const spirv::IdRefList &parameters);
275     spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
276                                                    spirv::IdRef typeId,
277                                                    const spirv::IdRefList &parameters);
278     spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
279                                                     spirv::IdRef typeId,
280                                                     const spirv::IdRefList &parameters);
281     spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
282                                                    spirv::IdRef typeId,
283                                                    const spirv::IdRefList &parameters);
284     // Load N values where N is the number of node's children.  In some cases, the last M values are
285     // lvalues which should be skipped.
286     spirv::IdRefList loadAllParams(TIntermOperator *node,
287                                    size_t skipCount,
288                                    spirv::IdRefList *paramTypeIds);
289     void extractComponents(TIntermAggregate *node,
290                            size_t componentCount,
291                            const spirv::IdRefList &parameters,
292                            spirv::IdRefList *extractedComponentsOut);
293 
294     void startShortCircuit(TIntermBinary *node);
295     spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
296 
297     spirv::IdRef visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId);
298     spirv::IdRef createCompare(TIntermOperator *node, spirv::IdRef resultTypeId);
299     spirv::IdRef createAtomicBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
300     spirv::IdRef createImageTextureBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
301     spirv::IdRef createSubpassLoadBuiltIn(TIntermOperator *node, spirv::IdRef resultTypeId);
302     spirv::IdRef createInterpolate(TIntermOperator *node, spirv::IdRef resultTypeId);
303 
304     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
305 
306     void visitArrayLength(TIntermUnary *node);
307 
308     // Cast between types.  There are two kinds of casts:
309     //
310     // - A constructor can cast between basic types, for example vec4(someInt).
311     // - Assignments, constructors, function calls etc may copy an array or struct between different
312     //   block storages, invariance etc (which due to their decorations generate different SPIR-V
313     //   types).  For example:
314     //
315     //       layout(std140) uniform U { invariant Struct s; } u; ... Struct s2 = u.s;
316     //
317     spirv::IdRef castBasicType(spirv::IdRef value,
318                                const TType &valueType,
319                                const TType &expectedType,
320                                spirv::IdRef *resultTypeIdOut);
321     spirv::IdRef cast(spirv::IdRef value,
322                       const TType &valueType,
323                       const SpirvTypeSpec &valueTypeSpec,
324                       const SpirvTypeSpec &expectedTypeSpec,
325                       spirv::IdRef *resultTypeIdOut);
326 
327     // Given a list of parameters to an operator, extend the scalars to match the vectors.  GLSL
328     // frequently has operators that mix vectors and scalars, while SPIR-V usually applies the
329     // operations per component (requiring the scalars to turn into a vector).
330     void extendScalarParamsToVector(TIntermOperator *node,
331                                     spirv::IdRef resultTypeId,
332                                     spirv::IdRefList *parameters);
333 
334     // Helper to reduce vector == and != with OpAll and OpAny respectively.  If multiple ids are
335     // given, either OpLogicalAnd or OpLogicalOr is used (if two operands) or a bool vector is
336     // constructed and OpAll and OpAny used.
337     spirv::IdRef reduceBoolVector(TOperator op,
338                                   const spirv::IdRefList &valueIds,
339                                   spirv::IdRef typeId,
340                                   const SpirvDecorations &decorations);
341     // Helper to implement == and !=, supporting vectors, matrices, structs and arrays.
342     void createCompareImpl(TOperator op,
343                            const TType &operandType,
344                            spirv::IdRef resultTypeId,
345                            spirv::IdRef leftId,
346                            spirv::IdRef rightId,
347                            const SpirvDecorations &operandDecorations,
348                            const SpirvDecorations &resultDecorations,
349                            spirv::LiteralIntegerList *currentAccessChain,
350                            spirv::IdRefList *intermediateResultsOut);
351 
352     // For some builtins, SPIR-V outputs two values in a struct.  This function defines such a
353     // struct if not already defined.
354     spirv::IdRef makeBuiltInOutputStructType(TIntermOperator *node, size_t lvalueCount);
355     // Once the builtin instruction is generated, the two return values are extracted from the
356     // struct.  These are written to the return value (if any) and the out parameters.
357     void storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator *node,
358                                                         size_t lvalueCount,
359                                                         spirv::IdRef structValue,
360                                                         spirv::IdRef returnValue,
361                                                         spirv::IdRef returnValueType);
362     void storeBuiltInStructOutputInParamHelper(NodeData *data,
363                                                TIntermTyped *param,
364                                                spirv::IdRef structValue,
365                                                uint32_t fieldIndex);
366 
367     void markVertexOutputOnShaderEnd();
368     void markVertexOutputOnEmitVertex();
369 
370     TCompiler *mCompiler;
371     ANGLE_MAYBE_UNUSED_PRIVATE_FIELD const ShCompileOptions &mCompileOptions;
372 
373     SPIRVBuilder mBuilder;
374 
375     // Traversal state.  Nodes generally push() once to this stack on PreVisit.  On InVisit and
376     // PostVisit, they pop() once (data corresponding to the result of the child) and accumulate it
377     // in back() (data corresponding to the node itself).  On PostVisit, code is generated.
378     std::vector<NodeData> mNodeData;
379 
380     // A map of TSymbol to its SPIR-V id.  This could be a:
381     //
382     // - TVariable, or
383     // - TInterfaceBlock: because TIntermSymbols referencing a field of an unnamed interface block
384     //   don't reference the TVariable that defines the struct, but the TInterfaceBlock itself.
385     angle::HashMap<const TSymbol *, spirv::IdRef> mSymbolIdMap;
386 
387     // A map of TFunction to its various SPIR-V ids.
388     angle::HashMap<const TFunction *, FunctionIds> mFunctionIdMap;
389 
390     // A map of internally defined structs used to capture result of some SPIR-V instructions.
391     angle::HashMap<BuiltInResultStruct, spirv::IdRef, BuiltInResultStructHash>
392         mBuiltInResultStructMap;
393 
394     // Whether the current symbol being visited is being declared.
395     bool mIsSymbolBeingDeclared = false;
396 
397     // What is the id of the current function being generated.
398     spirv::IdRef mCurrentFunctionId;
399 };
400 
GetStorageClass(const ShCompileOptions & compileOptions,const TType & type,GLenum shaderType)401 spv::StorageClass GetStorageClass(const ShCompileOptions &compileOptions,
402                                   const TType &type,
403                                   GLenum shaderType)
404 {
405     // Opaque uniforms (samplers, images and subpass inputs) have the UniformConstant storage class
406     if (IsOpaqueType(type.getBasicType()))
407     {
408         return spv::StorageClassUniformConstant;
409     }
410 
411     const TQualifier qualifier = type.getQualifier();
412 
413     // Input varying and IO blocks have the Input storage class
414     if (IsShaderIn(qualifier))
415     {
416         return spv::StorageClassInput;
417     }
418 
419     // Output varying and IO blocks have the Input storage class
420     if (IsShaderOut(qualifier))
421     {
422         return spv::StorageClassOutput;
423     }
424 
425     switch (qualifier)
426     {
427         case EvqShared:
428             // Compute shader shared memory has the Workgroup storage class
429             return spv::StorageClassWorkgroup;
430 
431         case EvqGlobal:
432         case EvqConst:
433             // Global variables have the Private class.  Complex constant variables that are not
434             // folded are also defined globally.
435             return spv::StorageClassPrivate;
436 
437         case EvqTemporary:
438         case EvqParamIn:
439         case EvqParamOut:
440         case EvqParamInOut:
441             // Function-local variables have the Function class
442             return spv::StorageClassFunction;
443 
444         case EvqVertexID:
445         case EvqInstanceID:
446         case EvqFragCoord:
447         case EvqFrontFacing:
448         case EvqPointCoord:
449         case EvqSampleID:
450         case EvqSamplePosition:
451         case EvqSampleMaskIn:
452         case EvqPatchVerticesIn:
453         case EvqTessCoord:
454         case EvqPrimitiveIDIn:
455         case EvqInvocationID:
456         case EvqHelperInvocation:
457         case EvqNumWorkGroups:
458         case EvqWorkGroupID:
459         case EvqLocalInvocationID:
460         case EvqGlobalInvocationID:
461         case EvqLocalInvocationIndex:
462         case EvqViewIDOVR:
463         case EvqLayerIn:
464         case EvqLastFragColor:
465         case EvqLastFragDepth:
466         case EvqLastFragStencil:
467             return spv::StorageClassInput;
468 
469         case EvqPosition:
470         case EvqPointSize:
471         case EvqFragDepth:
472         case EvqSampleMask:
473         case EvqLayerOut:
474             return spv::StorageClassOutput;
475 
476         case EvqClipDistance:
477         case EvqCullDistance:
478             // gl_Clip/CullDistance (not accessed through gl_in/gl_out) are inputs in FS and outputs
479             // otherwise.
480             return shaderType == GL_FRAGMENT_SHADER ? spv::StorageClassInput
481                                                     : spv::StorageClassOutput;
482 
483         case EvqTessLevelOuter:
484         case EvqTessLevelInner:
485             // gl_TessLevelOuter/Inner are outputs in TCS and inputs in TES.
486             return shaderType == GL_TESS_CONTROL_SHADER_EXT ? spv::StorageClassOutput
487                                                             : spv::StorageClassInput;
488 
489         case EvqPrimitiveID:
490             // gl_PrimitiveID is output in GS and input in TCS, TES and FS.
491             return shaderType == GL_GEOMETRY_SHADER ? spv::StorageClassOutput
492                                                     : spv::StorageClassInput;
493 
494         default:
495             // Uniform buffers have the Uniform storage class.  Storage buffers have the Uniform
496             // storage class in SPIR-V 1.3, and the StorageBuffer storage class in SPIR-V 1.4.
497             // Default uniforms are gathered in a uniform block as well.  Push constants use the
498             // PushConstant storage class instead.
499             ASSERT(type.getInterfaceBlock() != nullptr || qualifier == EvqUniform);
500             // I/O blocks must have already been classified as input or output above.
501             ASSERT(!IsShaderIoBlock(qualifier));
502 
503             if (type.getLayoutQualifier().pushConstant)
504             {
505                 ASSERT(type.getInterfaceBlock() != nullptr);
506                 return spv::StorageClassPushConstant;
507             }
508             return compileOptions.emitSPIRV14 && qualifier == EvqBuffer
509                        ? spv::StorageClassStorageBuffer
510                        : spv::StorageClassUniform;
511     }
512 }
513 
OutputSPIRVTraverser(TCompiler * compiler,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)514 OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler,
515                                            const ShCompileOptions &compileOptions,
516                                            const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
517                                            uint32_t firstUnusedSpirvId)
518     : TIntermTraverser(true, true, true, &compiler->getSymbolTable()),
519       mCompiler(compiler),
520       mCompileOptions(compileOptions),
521       mBuilder(compiler, compileOptions, uniqueToSpirvIdMap, firstUnusedSpirvId)
522 {}
523 
~OutputSPIRVTraverser()524 OutputSPIRVTraverser::~OutputSPIRVTraverser()
525 {
526     ASSERT(mNodeData.empty());
527 }
528 
getSymbolIdAndStorageClass(const TSymbol * symbol,const TType & type,spv::StorageClass * storageClass)529 spirv::IdRef OutputSPIRVTraverser::getSymbolIdAndStorageClass(const TSymbol *symbol,
530                                                               const TType &type,
531                                                               spv::StorageClass *storageClass)
532 {
533     *storageClass = GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
534     auto iter     = mSymbolIdMap.find(symbol);
535     if (iter != mSymbolIdMap.end())
536     {
537         return iter->second;
538     }
539 
540     // This must be an implicitly defined variable, define it now.
541     const char *name                = nullptr;
542     spv::BuiltIn builtInDecoration  = spv::BuiltInMax;
543     const TSymbolUniqueId *uniqueId = nullptr;
544 
545     switch (type.getQualifier())
546     {
547         // Vertex shader built-ins
548         case EvqVertexID:
549             name              = "gl_VertexIndex";
550             builtInDecoration = spv::BuiltInVertexIndex;
551             break;
552         case EvqInstanceID:
553             name              = "gl_InstanceIndex";
554             builtInDecoration = spv::BuiltInInstanceIndex;
555             break;
556 
557         // Fragment shader built-ins
558         case EvqFragCoord:
559             name              = "gl_FragCoord";
560             builtInDecoration = spv::BuiltInFragCoord;
561             break;
562         case EvqFrontFacing:
563             name              = "gl_FrontFacing";
564             builtInDecoration = spv::BuiltInFrontFacing;
565             break;
566         case EvqPointCoord:
567             name              = "gl_PointCoord";
568             builtInDecoration = spv::BuiltInPointCoord;
569             break;
570         case EvqFragDepth:
571             name              = "gl_FragDepth";
572             builtInDecoration = spv::BuiltInFragDepth;
573             mBuilder.addExecutionMode(spv::ExecutionModeDepthReplacing);
574             switch (type.getLayoutQualifier().depth)
575             {
576                 case EdGreater:
577                     mBuilder.addExecutionMode(spv::ExecutionModeDepthGreater);
578                     break;
579                 case EdLess:
580                     mBuilder.addExecutionMode(spv::ExecutionModeDepthLess);
581                     break;
582                 case EdUnchanged:
583                     mBuilder.addExecutionMode(spv::ExecutionModeDepthUnchanged);
584                     break;
585                 default:
586                     break;
587             }
588             break;
589         case EvqSampleMask:
590             name              = "gl_SampleMask";
591             builtInDecoration = spv::BuiltInSampleMask;
592             break;
593         case EvqSampleMaskIn:
594             name              = "gl_SampleMaskIn";
595             builtInDecoration = spv::BuiltInSampleMask;
596             break;
597         case EvqSampleID:
598             name              = "gl_SampleID";
599             builtInDecoration = spv::BuiltInSampleId;
600             mBuilder.addCapability(spv::CapabilitySampleRateShading);
601             uniqueId = &symbol->uniqueId();
602             break;
603         case EvqSamplePosition:
604             name              = "gl_SamplePosition";
605             builtInDecoration = spv::BuiltInSamplePosition;
606             mBuilder.addCapability(spv::CapabilitySampleRateShading);
607             break;
608         case EvqClipDistance:
609             name              = "gl_ClipDistance";
610             builtInDecoration = spv::BuiltInClipDistance;
611             mBuilder.addCapability(spv::CapabilityClipDistance);
612             break;
613         case EvqCullDistance:
614             name              = "gl_CullDistance";
615             builtInDecoration = spv::BuiltInCullDistance;
616             mBuilder.addCapability(spv::CapabilityCullDistance);
617             break;
618         case EvqHelperInvocation:
619             name              = "gl_HelperInvocation";
620             builtInDecoration = spv::BuiltInHelperInvocation;
621             break;
622 
623         // Tessellation built-ins
624         case EvqPatchVerticesIn:
625             name              = "gl_PatchVerticesIn";
626             builtInDecoration = spv::BuiltInPatchVertices;
627             break;
628         case EvqTessLevelOuter:
629             name              = "gl_TessLevelOuter";
630             builtInDecoration = spv::BuiltInTessLevelOuter;
631             break;
632         case EvqTessLevelInner:
633             name              = "gl_TessLevelInner";
634             builtInDecoration = spv::BuiltInTessLevelInner;
635             break;
636         case EvqTessCoord:
637             name              = "gl_TessCoord";
638             builtInDecoration = spv::BuiltInTessCoord;
639             break;
640 
641         // Shared geometry and tessellation built-ins
642         case EvqInvocationID:
643             name              = "gl_InvocationID";
644             builtInDecoration = spv::BuiltInInvocationId;
645             break;
646         case EvqPrimitiveID:
647             name              = "gl_PrimitiveID";
648             builtInDecoration = spv::BuiltInPrimitiveId;
649 
650             // In fragment shader, add the Geometry capability.
651             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
652             {
653                 mBuilder.addCapability(spv::CapabilityGeometry);
654             }
655 
656             break;
657 
658         // Geometry shader built-ins
659         case EvqPrimitiveIDIn:
660             name              = "gl_PrimitiveIDIn";
661             builtInDecoration = spv::BuiltInPrimitiveId;
662             break;
663         case EvqLayerOut:
664         case EvqLayerIn:
665             name              = "gl_Layer";
666             builtInDecoration = spv::BuiltInLayer;
667 
668             // gl_Layer requires the Geometry capability, even in fragment shaders.
669             mBuilder.addCapability(spv::CapabilityGeometry);
670 
671             break;
672 
673         // Compute shader built-ins
674         case EvqNumWorkGroups:
675             name              = "gl_NumWorkGroups";
676             builtInDecoration = spv::BuiltInNumWorkgroups;
677             break;
678         case EvqWorkGroupID:
679             name              = "gl_WorkGroupID";
680             builtInDecoration = spv::BuiltInWorkgroupId;
681             break;
682         case EvqLocalInvocationID:
683             name              = "gl_LocalInvocationID";
684             builtInDecoration = spv::BuiltInLocalInvocationId;
685             break;
686         case EvqGlobalInvocationID:
687             name              = "gl_GlobalInvocationID";
688             builtInDecoration = spv::BuiltInGlobalInvocationId;
689             break;
690         case EvqLocalInvocationIndex:
691             name              = "gl_LocalInvocationIndex";
692             builtInDecoration = spv::BuiltInLocalInvocationIndex;
693             break;
694 
695         // Extensions
696         case EvqViewIDOVR:
697             name              = "gl_ViewID_OVR";
698             builtInDecoration = spv::BuiltInViewIndex;
699             mBuilder.addCapability(spv::CapabilityMultiView);
700             mBuilder.addExtension(SPIRVExtensions::MultiviewOVR);
701             break;
702 
703         default:
704             UNREACHABLE();
705     }
706 
707     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
708     const spirv::IdRef varId  = mBuilder.declareVariable(
709         typeId, *storageClass, mBuilder.getDecorations(type), nullptr, name, uniqueId);
710 
711     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationBuiltIn,
712                          {spirv::LiteralInteger(builtInDecoration)});
713 
714     // Additionally:
715     //
716     // - decorate int inputs in FS with Flat (gl_Layer, gl_SampleID, gl_PrimitiveID, gl_ViewID_OVR).
717     // - decorate gl_TessLevel* with Patch.
718     switch (type.getQualifier())
719     {
720         case EvqLayerIn:
721         case EvqSampleID:
722         case EvqPrimitiveID:
723         case EvqViewIDOVR:
724             if (mCompiler->getShaderType() == GL_FRAGMENT_SHADER)
725             {
726                 spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationFlat,
727                                      {});
728             }
729             break;
730         case EvqTessLevelInner:
731         case EvqTessLevelOuter:
732             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), varId, spv::DecorationPatch, {});
733             break;
734         default:
735             break;
736     }
737 
738     mSymbolIdMap.insert({symbol, varId});
739     return varId;
740 }
741 
nodeDataInitLValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId,spv::StorageClass storageClass,const SpirvTypeSpec & typeSpec) const742 void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
743                                               spirv::IdRef baseId,
744                                               spirv::IdRef typeId,
745                                               spv::StorageClass storageClass,
746                                               const SpirvTypeSpec &typeSpec) const
747 {
748     *data = {};
749 
750     // Initialize the access chain as an lvalue.  Useful when an access chain is resolved, but needs
751     // to be replaced by a reference to a temporary variable holding the result.
752     data->baseId                       = baseId;
753     data->accessChain.baseTypeId       = typeId;
754     data->accessChain.preSwizzleTypeId = typeId;
755     data->accessChain.storageClass     = storageClass;
756     data->accessChain.typeSpec         = typeSpec;
757 }
758 
nodeDataInitRValue(NodeData * data,spirv::IdRef baseId,spirv::IdRef typeId) const759 void OutputSPIRVTraverser::nodeDataInitRValue(NodeData *data,
760                                               spirv::IdRef baseId,
761                                               spirv::IdRef typeId) const
762 {
763     *data = {};
764 
765     // Initialize the access chain as an rvalue.  Useful when an access chain is resolved, and needs
766     // to be replaced by a reference to it.
767     data->baseId                       = baseId;
768     data->accessChain.baseTypeId       = typeId;
769     data->accessChain.preSwizzleTypeId = typeId;
770 }
771 
accessChainOnPush(NodeData * data,const TType & parentType,size_t index)772 void OutputSPIRVTraverser::accessChainOnPush(NodeData *data, const TType &parentType, size_t index)
773 {
774     AccessChain &accessChain = data->accessChain;
775 
776     // Adjust |typeSpec| based on the type (which implies what the index does; select an array
777     // element, a block field etc).  Index is only meaningful for selecting block fields.
778     if (parentType.isArray())
779     {
780         accessChain.typeSpec.onArrayElementSelection(
781             (parentType.getStruct() != nullptr || parentType.isInterfaceBlock()),
782             parentType.isArrayOfArrays());
783     }
784     else if (parentType.isInterfaceBlock() || parentType.getStruct() != nullptr)
785     {
786         const TFieldListCollection *block = parentType.getInterfaceBlock();
787         if (!parentType.isInterfaceBlock())
788         {
789             block = parentType.getStruct();
790         }
791 
792         const TType &fieldType = *block->fields()[index]->type();
793         accessChain.typeSpec.onBlockFieldSelection(fieldType);
794     }
795     else if (parentType.isMatrix())
796     {
797         accessChain.typeSpec.onMatrixColumnSelection();
798     }
799     else
800     {
801         ASSERT(parentType.isVector());
802         accessChain.typeSpec.onVectorComponentSelection();
803     }
804 }
805 
accessChainPush(NodeData * data,spirv::IdRef index,spirv::IdRef typeId) const806 void OutputSPIRVTraverser::accessChainPush(NodeData *data,
807                                            spirv::IdRef index,
808                                            spirv::IdRef typeId) const
809 {
810     // Simply add the index to the chain of indices.
811     data->idList.emplace_back(index);
812     data->accessChain.areAllIndicesLiteral = false;
813     data->accessChain.preSwizzleTypeId     = typeId;
814 }
815 
accessChainPushLiteral(NodeData * data,spirv::LiteralInteger index,spirv::IdRef typeId) const816 void OutputSPIRVTraverser::accessChainPushLiteral(NodeData *data,
817                                                   spirv::LiteralInteger index,
818                                                   spirv::IdRef typeId) const
819 {
820     // Add the literal integer in the chain of indices.  Since this is an id list, fake it as an id.
821     data->idList.emplace_back(index);
822     data->accessChain.preSwizzleTypeId = typeId;
823 
824     // Literal index of a swizzle must be already folded in the AST.
825     ASSERT(data->accessChain.swizzles.empty());
826 }
827 
accessChainPushSwizzle(NodeData * data,const TVector<int> & swizzle,spirv::IdRef typeId,uint8_t componentCount) const828 void OutputSPIRVTraverser::accessChainPushSwizzle(NodeData *data,
829                                                   const TVector<int> &swizzle,
830                                                   spirv::IdRef typeId,
831                                                   uint8_t componentCount) const
832 {
833     AccessChain &accessChain = data->accessChain;
834 
835     // Record the swizzle as multi-component swizzles require special handling.  When loading
836     // through the access chain, the swizzle is applied after loading the vector first (see
837     // |accessChainLoad()|).  When storing through the access chain, the whole vector is loaded,
838     // swizzled components overwritten and the whole vector written back (see |accessChainStore()|).
839     ASSERT(accessChain.swizzles.empty());
840 
841     if (swizzle.size() == 1)
842     {
843         // If this swizzle is selecting a single component, fold it into the access chain.
844         accessChainPushLiteral(data, spirv::LiteralInteger(swizzle[0]), typeId);
845     }
846     else
847     {
848         // Otherwise keep them separate.
849         accessChain.swizzles.insert(accessChain.swizzles.end(), swizzle.begin(), swizzle.end());
850         accessChain.postSwizzleTypeId            = typeId;
851         accessChain.swizzledVectorComponentCount = componentCount;
852     }
853 }
854 
accessChainPushDynamicComponent(NodeData * data,spirv::IdRef index,spirv::IdRef typeId)855 void OutputSPIRVTraverser::accessChainPushDynamicComponent(NodeData *data,
856                                                            spirv::IdRef index,
857                                                            spirv::IdRef typeId)
858 {
859     AccessChain &accessChain = data->accessChain;
860 
861     // Record the index used to dynamically select a component of a vector.
862     ASSERT(!accessChain.dynamicComponent.valid());
863 
864     if (IsAccessChainRValue(accessChain) && accessChain.areAllIndicesLiteral)
865     {
866         // If the access chain is an rvalue with all-literal indices, keep this index separate so
867         // that OpCompositeExtract can be used for the access chain up to this index.
868         accessChain.dynamicComponent           = index;
869         accessChain.postDynamicComponentTypeId = typeId;
870         return;
871     }
872 
873     if (!accessChain.swizzles.empty())
874     {
875         // Otherwise if there's a swizzle, fold the swizzle and dynamic component selection into a
876         // single dynamic component selection.
877         ASSERT(accessChain.swizzles.size() > 1);
878 
879         // Create a vector constant from the swizzles.
880         spirv::IdRefList swizzleIds;
881         for (uint32_t component : accessChain.swizzles)
882         {
883             swizzleIds.push_back(mBuilder.getUintConstant(component));
884         }
885 
886         const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
887         const spirv::IdRef uvecTypeId = mBuilder.getBasicTypeId(EbtUInt, swizzleIds.size());
888 
889         const spirv::IdRef swizzlesId = mBuilder.getNewId({});
890         spirv::WriteConstantComposite(mBuilder.getSpirvTypeAndConstantDecls(), uvecTypeId,
891                                       swizzlesId, swizzleIds);
892 
893         // Index that vector constant with the dynamic index.  For example, vec.ywxz[i] becomes the
894         // constant {1, 3, 0, 2} indexed with i, and that index used on vec.
895         const spirv::IdRef newIndex = mBuilder.getNewId({});
896         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId,
897                                          newIndex, swizzlesId, index);
898 
899         index = newIndex;
900         accessChain.swizzles.clear();
901     }
902 
903     // Fold it into the access chain.
904     accessChainPush(data, index, typeId);
905 }
906 
accessChainCollapse(NodeData * data)907 spirv::IdRef OutputSPIRVTraverser::accessChainCollapse(NodeData *data)
908 {
909     AccessChain &accessChain = data->accessChain;
910 
911     ASSERT(accessChain.storageClass != spv::StorageClassMax);
912 
913     if (accessChain.accessChainId.valid())
914     {
915         return accessChain.accessChainId;
916     }
917 
918     // If there are no indices, the baseId is where access is done to/from.
919     if (data->idList.empty())
920     {
921         accessChain.accessChainId = data->baseId;
922         return accessChain.accessChainId;
923     }
924 
925     // Otherwise create an OpAccessChain instruction.  Swizzle handling is special as it selects
926     // multiple components, and is done differently for load and store.
927     spirv::IdRefList indexIds;
928     makeAccessChainIdList(data, &indexIds);
929 
930     const spirv::IdRef typePointerId =
931         mBuilder.getTypePointerId(accessChain.preSwizzleTypeId, accessChain.storageClass);
932 
933     accessChain.accessChainId = mBuilder.getNewId({});
934     spirv::WriteAccessChain(mBuilder.getSpirvCurrentFunctionBlock(), typePointerId,
935                             accessChain.accessChainId, data->baseId, indexIds);
936 
937     return accessChain.accessChainId;
938 }
939 
accessChainLoad(NodeData * data,const TType & valueType,spirv::IdRef * resultTypeIdOut)940 spirv::IdRef OutputSPIRVTraverser::accessChainLoad(NodeData *data,
941                                                    const TType &valueType,
942                                                    spirv::IdRef *resultTypeIdOut)
943 {
944     const SpirvDecorations &decorations = mBuilder.getDecorations(valueType);
945 
946     // Loading through the access chain can generate different instructions based on whether it's an
947     // rvalue, the indices are literal, there's a swizzle etc.
948     //
949     // - If rvalue:
950     //  * With indices:
951     //   + All literal: OpCompositeExtract which uses literal integers to access the rvalue.
952     //   + Otherwise: Can't use OpAccessChain on an rvalue, so create a temporary variable, OpStore
953     //     the rvalue into it, then use OpAccessChain and OpLoad to load from it.
954     //  * Without indices: Take the base id.
955     // - If lvalue:
956     //  * With indices: Use OpAccessChain and OpLoad
957     //  * Without indices: Use OpLoad
958     // - With swizzle: Use OpVectorShuffle on the result of the previous step
959     // - With dynamic component: Use OpVectorExtractDynamic on the result of the previous step
960 
961     AccessChain &accessChain = data->accessChain;
962 
963     if (resultTypeIdOut)
964     {
965         *resultTypeIdOut = getAccessChainTypeId(data);
966     }
967 
968     spirv::IdRef loadResult = data->baseId;
969 
970     if (IsAccessChainRValue(accessChain))
971     {
972         if (data->idList.size() > 0)
973         {
974             if (accessChain.areAllIndicesLiteral)
975             {
976                 // Use OpCompositeExtract on an rvalue with all literal indices.
977                 spirv::LiteralIntegerList indexList;
978                 makeAccessChainLiteralList(data, &indexList);
979 
980                 const spirv::IdRef result = mBuilder.getNewId(decorations);
981                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
982                                              accessChain.preSwizzleTypeId, result, loadResult,
983                                              indexList);
984                 loadResult = result;
985             }
986             else
987             {
988                 // Create a temp variable to hold the rvalue so an access chain can be made on it.
989                 const spirv::IdRef tempVar =
990                     mBuilder.declareVariable(accessChain.baseTypeId, spv::StorageClassFunction,
991                                              decorations, nullptr, "indexable", nullptr);
992 
993                 // Write the rvalue into the temp variable
994                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVar, loadResult,
995                                   nullptr);
996 
997                 // Make the temp variable the source of the access chain.
998                 data->baseId                   = tempVar;
999                 data->accessChain.storageClass = spv::StorageClassFunction;
1000 
1001                 // Load from the temp variable.
1002                 const spirv::IdRef accessChainId = accessChainCollapse(data);
1003                 loadResult                       = mBuilder.getNewId(decorations);
1004                 spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(),
1005                                  accessChain.preSwizzleTypeId, loadResult, accessChainId, nullptr);
1006             }
1007         }
1008     }
1009     else
1010     {
1011         // Load from the access chain.
1012         const spirv::IdRef accessChainId = accessChainCollapse(data);
1013         loadResult                       = mBuilder.getNewId(decorations);
1014         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1015                          loadResult, accessChainId, nullptr);
1016     }
1017 
1018     if (!accessChain.swizzles.empty())
1019     {
1020         // Single-component swizzles are already folded into the index list.
1021         ASSERT(accessChain.swizzles.size() > 1);
1022 
1023         // Take the loaded value and use OpVectorShuffle to create the swizzle.
1024         spirv::LiteralIntegerList swizzleList;
1025         for (uint32_t component : accessChain.swizzles)
1026         {
1027             swizzleList.push_back(spirv::LiteralInteger(component));
1028         }
1029 
1030         const spirv::IdRef result = mBuilder.getNewId(decorations);
1031         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1032                                   accessChain.postSwizzleTypeId, result, loadResult, loadResult,
1033                                   swizzleList);
1034         loadResult = result;
1035     }
1036 
1037     if (accessChain.dynamicComponent.valid())
1038     {
1039         // Use OpVectorExtractDynamic to select the component.
1040         const spirv::IdRef result = mBuilder.getNewId(decorations);
1041         spirv::WriteVectorExtractDynamic(mBuilder.getSpirvCurrentFunctionBlock(),
1042                                          accessChain.postDynamicComponentTypeId, result, loadResult,
1043                                          accessChain.dynamicComponent);
1044         loadResult = result;
1045     }
1046 
1047     // Upon loading values, cast them to the default SPIR-V variant.
1048     const spirv::IdRef castResult =
1049         cast(loadResult, valueType, accessChain.typeSpec, {}, resultTypeIdOut);
1050 
1051     return castResult;
1052 }
1053 
accessChainStore(NodeData * data,spirv::IdRef value,const TType & valueType)1054 void OutputSPIRVTraverser::accessChainStore(NodeData *data,
1055                                             spirv::IdRef value,
1056                                             const TType &valueType)
1057 {
1058     // Storing through the access chain can generate different instructions based on whether the
1059     // there's a swizzle.
1060     //
1061     // - Without swizzle: Use OpAccessChain and OpStore
1062     // - With swizzle: Use OpAccessChain and OpLoad to load the vector, then use OpVectorShuffle to
1063     //   replace the components being overwritten.  Finally, use OpStore to write the result back.
1064 
1065     AccessChain &accessChain = data->accessChain;
1066 
1067     // Single-component swizzles are already folded into the indices.
1068     ASSERT(accessChain.swizzles.size() != 1);
1069     // Since store can only happen through lvalues, it's impossible to have a dynamic component as
1070     // that always gets folded into the indices except for rvalues.
1071     ASSERT(!accessChain.dynamicComponent.valid());
1072 
1073     const spirv::IdRef accessChainId = accessChainCollapse(data);
1074 
1075     // Store through the access chain.  The values are always cast to the default SPIR-V type
1076     // variant when loaded from memory and operated on as such.  When storing, we need to cast the
1077     // result to the variant specified by the access chain.
1078     value = cast(value, valueType, {}, accessChain.typeSpec, nullptr);
1079 
1080     if (!accessChain.swizzles.empty())
1081     {
1082         // Load the vector before the swizzle.
1083         const spirv::IdRef loadResult = mBuilder.getNewId({});
1084         spirv::WriteLoad(mBuilder.getSpirvCurrentFunctionBlock(), accessChain.preSwizzleTypeId,
1085                          loadResult, accessChainId, nullptr);
1086 
1087         // Overwrite the components being written.  This is done by first creating an identity
1088         // swizzle, then replacing the components being written with a swizzle from the value.  For
1089         // example, take the following:
1090         //
1091         //     vec4 v;
1092         //     v.zx = u;
1093         //
1094         // The OpVectorShuffle instruction takes two vectors (v and u) and selects components from
1095         // each (in this example, swizzles [0, 3] select from v and [4, 7] select from u).  This
1096         // algorithm first creates the identity swizzles {0, 1, 2, 3}, then replaces z and x (the
1097         // 0th and 2nd element) with swizzles from u (4 + {0, 1}) to get the result
1098         // {4+1, 1, 4+0, 3}.
1099 
1100         spirv::LiteralIntegerList swizzleList;
1101         for (uint32_t component = 0; component < accessChain.swizzledVectorComponentCount;
1102              ++component)
1103         {
1104             swizzleList.push_back(spirv::LiteralInteger(component));
1105         }
1106         uint32_t srcComponent = 0;
1107         for (uint32_t dstComponent : accessChain.swizzles)
1108         {
1109             swizzleList[dstComponent] =
1110                 spirv::LiteralInteger(accessChain.swizzledVectorComponentCount + srcComponent);
1111             ++srcComponent;
1112         }
1113 
1114         // Use the generated swizzle to select components from the loaded vector and the value to be
1115         // written.  Use the final result as the value to be written to the vector.
1116         const spirv::IdRef result = mBuilder.getNewId({});
1117         spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
1118                                   accessChain.preSwizzleTypeId, result, loadResult, value,
1119                                   swizzleList);
1120         value = result;
1121     }
1122 
1123     spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), accessChainId, value, nullptr);
1124 }
1125 
makeAccessChainIdList(NodeData * data,spirv::IdRefList * idsOut)1126 void OutputSPIRVTraverser::makeAccessChainIdList(NodeData *data, spirv::IdRefList *idsOut)
1127 {
1128     for (size_t index = 0; index < data->idList.size(); ++index)
1129     {
1130         spirv::IdRef indexId = data->idList[index].id;
1131 
1132         if (!indexId.valid())
1133         {
1134             // The index is a literal integer, so replace it with an OpConstant id.
1135             indexId = mBuilder.getUintConstant(data->idList[index].literal);
1136         }
1137 
1138         idsOut->push_back(indexId);
1139     }
1140 }
1141 
makeAccessChainLiteralList(NodeData * data,spirv::LiteralIntegerList * literalsOut)1142 void OutputSPIRVTraverser::makeAccessChainLiteralList(NodeData *data,
1143                                                       spirv::LiteralIntegerList *literalsOut)
1144 {
1145     for (size_t index = 0; index < data->idList.size(); ++index)
1146     {
1147         ASSERT(!data->idList[index].id.valid());
1148         literalsOut->push_back(data->idList[index].literal);
1149     }
1150 }
1151 
getAccessChainTypeId(NodeData * data)1152 spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
1153 {
1154     // Load and store through the access chain may be done in multiple steps.  These steps produce
1155     // the following types:
1156     //
1157     // - preSwizzleTypeId
1158     // - postSwizzleTypeId
1159     // - postDynamicComponentTypeId
1160     //
1161     // The last of these types is the final type of the expression this access chain corresponds to.
1162     const AccessChain &accessChain = data->accessChain;
1163 
1164     if (accessChain.postDynamicComponentTypeId.valid())
1165     {
1166         return accessChain.postDynamicComponentTypeId;
1167     }
1168     if (accessChain.postSwizzleTypeId.valid())
1169     {
1170         return accessChain.postSwizzleTypeId;
1171     }
1172     ASSERT(accessChain.preSwizzleTypeId.valid());
1173     return accessChain.preSwizzleTypeId;
1174 }
1175 
declareConst(TIntermDeclaration * decl)1176 void OutputSPIRVTraverser::declareConst(TIntermDeclaration *decl)
1177 {
1178     const TIntermSequence &sequence = *decl->getSequence();
1179     ASSERT(sequence.size() == 1);
1180 
1181     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1182     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1183 
1184     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1185     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqConst);
1186 
1187     TIntermTyped *initializer = assign->getRight();
1188     ASSERT(initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
1189 
1190     const TType &type         = symbol->getType();
1191     const TVariable *variable = &symbol->variable();
1192 
1193     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1194     const spirv::IdRef constId =
1195         createConstant(type, type.getBasicType(), initializer->getConstantValue(),
1196                        initializer->isConstantNullValue());
1197 
1198     // Remember the id of the variable for future look up.
1199     ASSERT(mSymbolIdMap.count(variable) == 0);
1200     mSymbolIdMap[variable] = constId;
1201 
1202     if (!mInGlobalScope)
1203     {
1204         mNodeData.emplace_back();
1205         nodeDataInitRValue(&mNodeData.back(), constId, typeId);
1206     }
1207 }
1208 
declareSpecConst(TIntermDeclaration * decl)1209 void OutputSPIRVTraverser::declareSpecConst(TIntermDeclaration *decl)
1210 {
1211     const TIntermSequence &sequence = *decl->getSequence();
1212     ASSERT(sequence.size() == 1);
1213 
1214     TIntermBinary *assign = sequence.front()->getAsBinaryNode();
1215     ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
1216 
1217     TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1218     ASSERT(symbol != nullptr && symbol->getType().getQualifier() == EvqSpecConst);
1219 
1220     TIntermConstantUnion *initializer = assign->getRight()->getAsConstantUnion();
1221     ASSERT(initializer != nullptr);
1222 
1223     const TType &type         = symbol->getType();
1224     const TVariable *variable = &symbol->variable();
1225 
1226     // All spec consts in ANGLE are initialized to 0.
1227     ASSERT(initializer->isZero(0));
1228 
1229     const spirv::IdRef specConstId = mBuilder.declareSpecConst(
1230         type.getBasicType(), type.getLayoutQualifier().location, mBuilder.getName(variable).data());
1231 
1232     // Remember the id of the variable for future look up.
1233     ASSERT(mSymbolIdMap.count(variable) == 0);
1234     mSymbolIdMap[variable] = specConstId;
1235 }
1236 
createConstant(const TType & type,TBasicType expectedBasicType,const TConstantUnion * constUnion,bool isConstantNullValue)1237 spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
1238                                                   TBasicType expectedBasicType,
1239                                                   const TConstantUnion *constUnion,
1240                                                   bool isConstantNullValue)
1241 {
1242     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
1243     spirv::IdRefList componentIds;
1244 
1245     // If the object is all zeros, use OpConstantNull to avoid creating a bunch of constants.  This
1246     // is not done for basic scalar types as some instructions require an OpConstant and validation
1247     // doesn't accept OpConstantNull (likely a spec bug).
1248     const size_t size            = type.getObjectSize();
1249     const TBasicType basicType   = type.getBasicType();
1250     const bool isBasicScalar     = size == 1 && (basicType == EbtFloat || basicType == EbtInt ||
1251                                              basicType == EbtUInt || basicType == EbtBool);
1252     const bool useOpConstantNull = isConstantNullValue && !isBasicScalar;
1253     if (useOpConstantNull)
1254     {
1255         return mBuilder.getNullConstant(typeId);
1256     }
1257 
1258     if (type.isArray())
1259     {
1260         TType elementType(type);
1261         elementType.toArrayElementType();
1262 
1263         // If it's an array constant, get the constant id of each element.
1264         for (unsigned int elementIndex = 0; elementIndex < type.getOutermostArraySize();
1265              ++elementIndex)
1266         {
1267             componentIds.push_back(
1268                 createConstant(elementType, expectedBasicType, constUnion, false));
1269             constUnion += elementType.getObjectSize();
1270         }
1271     }
1272     else if (type.getBasicType() == EbtStruct)
1273     {
1274         // If it's a struct constant, get the constant id for each field.
1275         for (const TField *field : type.getStruct()->fields())
1276         {
1277             const TType *fieldType = field->type();
1278             componentIds.push_back(
1279                 createConstant(*fieldType, fieldType->getBasicType(), constUnion, false));
1280 
1281             constUnion += fieldType->getObjectSize();
1282         }
1283     }
1284     else
1285     {
1286         // Otherwise get the constant id for each component.
1287         ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
1288                expectedBasicType == EbtUInt || expectedBasicType == EbtBool ||
1289                expectedBasicType == EbtYuvCscStandardEXT);
1290 
1291         for (size_t component = 0; component < size; ++component, ++constUnion)
1292         {
1293             spirv::IdRef componentId;
1294 
1295             // If the constant has a different type than expected, cast it right away.
1296             TConstantUnion castConstant;
1297             bool valid = castConstant.cast(expectedBasicType, *constUnion);
1298             ASSERT(valid);
1299 
1300             switch (castConstant.getType())
1301             {
1302                 case EbtFloat:
1303                     componentId = mBuilder.getFloatConstant(castConstant.getFConst());
1304                     break;
1305                 case EbtInt:
1306                     componentId = mBuilder.getIntConstant(castConstant.getIConst());
1307                     break;
1308                 case EbtUInt:
1309                     componentId = mBuilder.getUintConstant(castConstant.getUConst());
1310                     break;
1311                 case EbtBool:
1312                     componentId = mBuilder.getBoolConstant(castConstant.getBConst());
1313                     break;
1314                 case EbtYuvCscStandardEXT:
1315                     componentId =
1316                         mBuilder.getUintConstant(castConstant.getYuvCscStandardEXTConst());
1317                     break;
1318                 default:
1319                     UNREACHABLE();
1320             }
1321             componentIds.push_back(componentId);
1322         }
1323     }
1324 
1325     // If this is a composite, create a composite constant from the components.
1326     if (type.isArray() || type.getBasicType() == EbtStruct || componentIds.size() > 1)
1327     {
1328         return createComplexConstant(type, typeId, componentIds);
1329     }
1330 
1331     // Otherwise return the sole component.
1332     ASSERT(componentIds.size() == 1);
1333     return componentIds[0];
1334 }
1335 
createComplexConstant(const TType & type,spirv::IdRef typeId,const spirv::IdRefList & parameters)1336 spirv::IdRef OutputSPIRVTraverser::createComplexConstant(const TType &type,
1337                                                          spirv::IdRef typeId,
1338                                                          const spirv::IdRefList &parameters)
1339 {
1340     ASSERT(!type.isScalar());
1341 
1342     if (type.isMatrix() && !type.isArray())
1343     {
1344         ASSERT(parameters.size() == type.getRows() * type.getCols());
1345 
1346         // Matrices are constructed from their columns.
1347         spirv::IdRefList columnIds;
1348 
1349         const spirv::IdRef columnTypeId =
1350             mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1351 
1352         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1353         {
1354             auto columnParametersStart = parameters.begin() + columnIndex * type.getRows();
1355             spirv::IdRefList columnParameters(columnParametersStart,
1356                                               columnParametersStart + type.getRows());
1357 
1358             columnIds.push_back(mBuilder.getCompositeConstant(columnTypeId, columnParameters));
1359         }
1360 
1361         return mBuilder.getCompositeConstant(typeId, columnIds);
1362     }
1363 
1364     return mBuilder.getCompositeConstant(typeId, parameters);
1365 }
1366 
createConstructor(TIntermAggregate * node,spirv::IdRef typeId)1367 spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node, spirv::IdRef typeId)
1368 {
1369     const TType &type                = node->getType();
1370     const TIntermSequence &arguments = *node->getSequence();
1371     const TType &arg0Type            = arguments[0]->getAsTyped()->getType();
1372 
1373     // In some cases, constructors-with-constant values are not folded.  If the constructor is a
1374     // null value, use OpConstantNull to avoid creating a bunch of instructions.  Otherwise, the
1375     // constant is created below.
1376     if (node->isConstantNullValue())
1377     {
1378         return mBuilder.getNullConstant(typeId);
1379     }
1380 
1381     // Take each constructor argument that is visited and evaluate it as rvalue
1382     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
1383 
1384     // Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
1385     // (in each case, if the parameter doesn't match the type being constructed, it must be cast):
1386     //
1387     // - float(f): This should translate to just f
1388     // - float(v): This should translate to OpCompositeExtract %scalar %v 0
1389     // - float(m): This should translate to OpCompositeExtract %scalar %m 0 0
1390     // - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
1391     // - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
1392     //   results of v1.zy and v2.x.  However, for simplicity it's easier to generate that
1393     //   instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
1394     //   used as parameter).
1395     // - vecN(m): This takes N components from m in column-major order (for example, vec4
1396     //   constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
1397     //   This translates to OpCompositeConstruct with the id of the individual components extracted
1398     //   from m.
1399     // - matNxM(f): This creates a diagonal matrix.  It generates N OpCompositeConstruct
1400     //   instructions for each column (which are vecM), followed by an OpCompositeConstruct that
1401     //   constructs the final result.
1402     // - matNxM(m):
1403     //   * With m larger than NxM, this extracts a submatrix out of m.  It generates
1404     //     OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
1405     //     rows of m are more than M.  OpCompositeConstruct is used to construct the final result.
1406     //   * If m is not larger than NxM, an identity matrix is created and superimposed with m.
1407     //     OpCompositeExtract is used to extract each component of m (that is necessary), and
1408     //     together with the zero or one constants necessary used to create the columns (with
1409     //     OpCompositeConstruct).  OpCompositeConstruct is used to construct the final result.
1410     // - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
1411     //   are extracted from the parameters, which are divided up and used to construct each column,
1412     //   which is finally constructed into the final result.
1413     //
1414     // Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
1415     // each parameter which must enumerate every individual element / field.
1416 
1417     // In some cases, constructors-with-constant values are not folded such as for large constants.
1418     // Some transformations may also produce constructors-with-constants instead of constants even
1419     // for basic types.  These are handled here.
1420     if (node->hasConstantValue())
1421     {
1422         if (!type.isScalar())
1423         {
1424             return createComplexConstant(node->getType(), typeId, parameters);
1425         }
1426 
1427         // If a transformation creates scalar(constant), return the constant as-is.
1428         // visitConstantUnion has already cast it to the right type.
1429         if (arguments[0]->getAsConstantUnion() != nullptr)
1430         {
1431             return parameters[0];
1432         }
1433     }
1434 
1435     if (type.isArray() || type.getStruct() != nullptr)
1436     {
1437         return createArrayOrStructConstructor(node, typeId, parameters);
1438     }
1439 
1440     // The following are simple casts:
1441     //
1442     // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
1443     // - gvecN(vN) (where the argument is a single vector with the same number of components).
1444     // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note that
1445     //   matrices are always float, so there's no actual cast and this would be a no-op.
1446     //
1447     const bool isSingleScalarCast = arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
1448     const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
1449                                     arg0Type.isVector() &&
1450                                     type.getNominalSize() == arg0Type.getNominalSize();
1451     const bool isSingleMatrixCast = arguments.size() == 1 && type.isMatrix() &&
1452                                     arg0Type.isMatrix() && type.getCols() == arg0Type.getCols() &&
1453                                     type.getRows() == arg0Type.getRows();
1454     if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
1455     {
1456         return castBasicType(parameters[0], arg0Type, type, nullptr);
1457     }
1458 
1459     if (type.isScalar())
1460     {
1461         ASSERT(arguments.size() == 1);
1462         return createConstructorScalarFromNonScalar(node, typeId, parameters);
1463     }
1464 
1465     if (type.isVector())
1466     {
1467         if (arguments.size() == 1 && arg0Type.isScalar())
1468         {
1469             return createConstructorVectorFromScalar(arg0Type, type, typeId, parameters);
1470         }
1471         if (arg0Type.isMatrix())
1472         {
1473             // If the first argument is a matrix, it will always have enough components to fill an
1474             // entire vector, so it doesn't matter what's specified after it.
1475             return createConstructorVectorFromMatrix(node, typeId, parameters);
1476         }
1477         return createConstructorVectorFromMultiple(node, typeId, parameters);
1478     }
1479 
1480     ASSERT(type.isMatrix());
1481 
1482     if (arg0Type.isScalar() && arguments.size() == 1)
1483     {
1484         parameters[0] = castBasicType(parameters[0], arg0Type, type, nullptr);
1485         return createConstructorMatrixFromScalar(node, typeId, parameters);
1486     }
1487     if (arg0Type.isMatrix())
1488     {
1489         return createConstructorMatrixFromMatrix(node, typeId, parameters);
1490     }
1491     return createConstructorMatrixFromVectors(node, typeId, parameters);
1492 }
1493 
createArrayOrStructConstructor(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1494 spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
1495     TIntermAggregate *node,
1496     spirv::IdRef typeId,
1497     const spirv::IdRefList &parameters)
1498 {
1499     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1500     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1501                                    parameters);
1502     return result;
1503 }
1504 
createConstructorScalarFromNonScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1505 spirv::IdRef OutputSPIRVTraverser::createConstructorScalarFromNonScalar(
1506     TIntermAggregate *node,
1507     spirv::IdRef typeId,
1508     const spirv::IdRefList &parameters)
1509 {
1510     ASSERT(parameters.size() == 1);
1511     const TType &type     = node->getType();
1512     const TType &arg0Type = node->getChildNode(0)->getAsTyped()->getType();
1513 
1514     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(type));
1515 
1516     spirv::LiteralIntegerList indices = {spirv::LiteralInteger(0)};
1517     if (arg0Type.isMatrix())
1518     {
1519         indices.push_back(spirv::LiteralInteger(0));
1520     }
1521 
1522     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1523                                  mBuilder.getBasicTypeId(arg0Type.getBasicType(), 1), result,
1524                                  parameters[0], indices);
1525 
1526     TType arg0TypeAsScalar(arg0Type);
1527     arg0TypeAsScalar.toComponentType();
1528 
1529     return castBasicType(result, arg0TypeAsScalar, type, nullptr);
1530 }
1531 
createConstructorVectorFromScalar(const TType & parameterType,const TType & expectedType,spirv::IdRef typeId,const spirv::IdRefList & parameters)1532 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
1533     const TType &parameterType,
1534     const TType &expectedType,
1535     spirv::IdRef typeId,
1536     const spirv::IdRefList &parameters)
1537 {
1538     // vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
1539     ASSERT(parameters.size() == 1);
1540 
1541     const spirv::IdRef castParameter =
1542         castBasicType(parameters[0], parameterType, expectedType, nullptr);
1543 
1544     spirv::IdRefList replicatedParameter(expectedType.getNominalSize(), castParameter);
1545 
1546     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(parameterType));
1547     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1548                                    replicatedParameter);
1549     return result;
1550 }
1551 
createConstructorVectorFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1552 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMatrix(
1553     TIntermAggregate *node,
1554     spirv::IdRef typeId,
1555     const spirv::IdRefList &parameters)
1556 {
1557     // vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
1558     spirv::IdRefList extractedComponents;
1559     extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
1560 
1561     // Construct the vector with the basic type of the argument, and cast it at end if needed.
1562     ASSERT(parameters.size() == 1);
1563     const TType &arg0Type     = node->getChildNode(0)->getAsTyped()->getType();
1564     const TType &expectedType = node->getType();
1565 
1566     spirv::IdRef argumentTypeId = typeId;
1567     TType arg0TypeAsVector(arg0Type);
1568     arg0TypeAsVector.setPrimarySize(node->getType().getNominalSize());
1569     arg0TypeAsVector.setSecondarySize(1);
1570 
1571     if (arg0Type.getBasicType() != expectedType.getBasicType())
1572     {
1573         argumentTypeId = mBuilder.getTypeData(arg0TypeAsVector, {}).id;
1574     }
1575 
1576     spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1577     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), argumentTypeId, result,
1578                                    extractedComponents);
1579 
1580     if (arg0Type.getBasicType() != expectedType.getBasicType())
1581     {
1582         result = castBasicType(result, arg0TypeAsVector, expectedType, nullptr);
1583     }
1584 
1585     return result;
1586 }
1587 
createConstructorVectorFromMultiple(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1588 spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromMultiple(
1589     TIntermAggregate *node,
1590     spirv::IdRef typeId,
1591     const spirv::IdRefList &parameters)
1592 {
1593     const TType &type = node->getType();
1594     // vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
1595     spirv::IdRefList extractedComponents;
1596     extractComponents(node, type.getNominalSize(), parameters, &extractedComponents);
1597 
1598     // Handle the case where a matrix is used in the constructor anywhere but the first place.  In
1599     // that case, the components extracted from the matrix might need casting to the right type.
1600     const TIntermSequence &arguments = *node->getSequence();
1601     for (size_t argumentIndex = 0, componentIndex = 0;
1602          argumentIndex < arguments.size() && componentIndex < extractedComponents.size();
1603          ++argumentIndex)
1604     {
1605         TIntermNode *argument     = arguments[argumentIndex];
1606         const TType &argumentType = argument->getAsTyped()->getType();
1607         if (argumentType.isScalar() || argumentType.isVector())
1608         {
1609             // extractComponents already casts scalar and vector components.
1610             componentIndex += argumentType.getNominalSize();
1611             continue;
1612         }
1613 
1614         TType componentType(argumentType);
1615         componentType.toComponentType();
1616 
1617         for (uint8_t columnIndex = 0;
1618              columnIndex < argumentType.getCols() && componentIndex < extractedComponents.size();
1619              ++columnIndex)
1620         {
1621             for (uint8_t rowIndex = 0;
1622                  rowIndex < argumentType.getRows() && componentIndex < extractedComponents.size();
1623                  ++rowIndex, ++componentIndex)
1624             {
1625                 extractedComponents[componentIndex] = castBasicType(
1626                     extractedComponents[componentIndex], componentType, type, nullptr);
1627             }
1628         }
1629 
1630         // Matrices all have enough components to fill a vector, so it's impossible to need to visit
1631         // any other arguments that may come after.
1632         ASSERT(componentIndex == extractedComponents.size());
1633     }
1634 
1635     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
1636     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1637                                    extractedComponents);
1638     return result;
1639 }
1640 
createConstructorMatrixFromScalar(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1641 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
1642     TIntermAggregate *node,
1643     spirv::IdRef typeId,
1644     const spirv::IdRefList &parameters)
1645 {
1646     // matNxM(f) translates to
1647     //
1648     //     %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
1649     //     %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
1650     //     %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
1651     //     ...
1652     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1653 
1654     const TType &type           = node->getType();
1655     const spirv::IdRef scalarId = parameters[0];
1656     spirv::IdRef zeroId;
1657 
1658     SpirvDecorations decorations = mBuilder.getDecorations(type);
1659 
1660     switch (type.getBasicType())
1661     {
1662         case EbtFloat:
1663             zeroId = mBuilder.getFloatConstant(0);
1664             break;
1665         case EbtInt:
1666             zeroId = mBuilder.getIntConstant(0);
1667             break;
1668         case EbtUInt:
1669             zeroId = mBuilder.getUintConstant(0);
1670             break;
1671         case EbtBool:
1672             zeroId = mBuilder.getBoolConstant(0);
1673             break;
1674         default:
1675             UNREACHABLE();
1676     }
1677 
1678     spirv::IdRefList componentIds(type.getRows(), zeroId);
1679     spirv::IdRefList columnIds;
1680 
1681     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1682 
1683     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1684     {
1685         columnIds.push_back(mBuilder.getNewId(decorations));
1686 
1687         // Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
1688         if (columnIndex < type.getRows())
1689         {
1690             componentIds[columnIndex] = scalarId;
1691         }
1692         if (columnIndex > 0 && columnIndex <= type.getRows())
1693         {
1694             componentIds[columnIndex - 1] = zeroId;
1695         }
1696 
1697         // Create the column.
1698         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1699                                        columnIds.back(), componentIds);
1700     }
1701 
1702     // Create the matrix out of the columns.
1703     const spirv::IdRef result = mBuilder.getNewId(decorations);
1704     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1705                                    columnIds);
1706     return result;
1707 }
1708 
createConstructorMatrixFromVectors(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1709 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
1710     TIntermAggregate *node,
1711     spirv::IdRef typeId,
1712     const spirv::IdRefList &parameters)
1713 {
1714     // matNxM(v1.zy, v2.x, ...) translates to:
1715     //
1716     //     %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
1717     //     ...
1718     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1719 
1720     const TType &type = node->getType();
1721 
1722     SpirvDecorations decorations = mBuilder.getDecorations(type);
1723 
1724     spirv::IdRefList extractedComponents;
1725     extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
1726 
1727     spirv::IdRefList columnIds;
1728 
1729     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1730 
1731     // Chunk up the extracted components by column and construct intermediary vectors.
1732     for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1733     {
1734         columnIds.push_back(mBuilder.getNewId(decorations));
1735 
1736         auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
1737         const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
1738 
1739         // Create the column.
1740         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1741                                        columnIds.back(), componentIds);
1742     }
1743 
1744     const spirv::IdRef result = mBuilder.getNewId(decorations);
1745     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1746                                    columnIds);
1747     return result;
1748 }
1749 
createConstructorMatrixFromMatrix(TIntermAggregate * node,spirv::IdRef typeId,const spirv::IdRefList & parameters)1750 spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
1751     TIntermAggregate *node,
1752     spirv::IdRef typeId,
1753     const spirv::IdRefList &parameters)
1754 {
1755     // matNxM(m) translates to:
1756     //
1757     // - If m is SxR where S>=N and R>=M:
1758     //
1759     //     %c0 = OpCompositeExtract %vecR %m 0
1760     //     %c1 = OpCompositeExtract %vecR %m 1
1761     //     ...
1762     //     // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
1763     //     ...
1764     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
1765     //
1766     // - Otherwise, an identity matrix is created and superimposed by m:
1767     //
1768     //     %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
1769     //     %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
1770     //     %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
1771     //     %c3 = OpCompositeConstruct %vecM       %0       %0 %0 %1
1772     //     %m  = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
1773 
1774     const TType &type          = node->getType();
1775     const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
1776 
1777     SpirvDecorations decorations = mBuilder.getDecorations(type);
1778 
1779     ASSERT(parameters.size() == 1);
1780 
1781     spirv::IdRefList columnIds;
1782 
1783     const spirv::IdRef columnTypeId = mBuilder.getBasicTypeId(type.getBasicType(), type.getRows());
1784 
1785     if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
1786     {
1787         // If the parameter is a larger matrix than the constructor type, extract the columns
1788         // directly and potentially swizzle them.
1789         TType paramColumnType(parameterType);
1790         paramColumnType.toMatrixColumnType();
1791         const spirv::IdRef paramColumnTypeId = mBuilder.getTypeData(paramColumnType, {}).id;
1792 
1793         const bool needsSwizzle           = parameterType.getRows() > type.getRows();
1794         spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
1795                                              spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
1796         swizzle.resize_down(type.getRows());
1797 
1798         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1799         {
1800             // Extract the column.
1801             const spirv::IdRef parameterColumnId = mBuilder.getNewId(decorations);
1802             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), paramColumnTypeId,
1803                                          parameterColumnId, parameters[0],
1804                                          {spirv::LiteralInteger(columnIndex)});
1805 
1806             // If the column has too many components, select the appropriate number of components.
1807             spirv::IdRef constructorColumnId = parameterColumnId;
1808             if (needsSwizzle)
1809             {
1810                 constructorColumnId = mBuilder.getNewId(decorations);
1811                 spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1812                                           constructorColumnId, parameterColumnId, parameterColumnId,
1813                                           swizzle);
1814             }
1815 
1816             columnIds.push_back(constructorColumnId);
1817         }
1818     }
1819     else
1820     {
1821         // Otherwise create an identity matrix and fill in the components that can be taken from the
1822         // given parameter.
1823         TType paramComponentType(parameterType);
1824         paramComponentType.toComponentType();
1825         const spirv::IdRef paramComponentTypeId = mBuilder.getTypeData(paramComponentType, {}).id;
1826 
1827         for (uint8_t columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
1828         {
1829             spirv::IdRefList componentIds;
1830 
1831             for (uint8_t componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
1832             {
1833                 // Take the component from the constructor parameter if possible.
1834                 spirv::IdRef componentId;
1835                 if (componentIndex < parameterType.getRows() &&
1836                     columnIndex < parameterType.getCols())
1837                 {
1838                     componentId = mBuilder.getNewId(decorations);
1839                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1840                                                  paramComponentTypeId, componentId, parameters[0],
1841                                                  {spirv::LiteralInteger(columnIndex),
1842                                                   spirv::LiteralInteger(componentIndex)});
1843                 }
1844                 else
1845                 {
1846                     const bool isOnDiagonal = columnIndex == componentIndex;
1847                     switch (type.getBasicType())
1848                     {
1849                         case EbtFloat:
1850                             componentId = mBuilder.getFloatConstant(isOnDiagonal ? 1.0f : 0.0f);
1851                             break;
1852                         case EbtInt:
1853                             componentId = mBuilder.getIntConstant(isOnDiagonal ? 1 : 0);
1854                             break;
1855                         case EbtUInt:
1856                             componentId = mBuilder.getUintConstant(isOnDiagonal ? 1 : 0);
1857                             break;
1858                         case EbtBool:
1859                             componentId = mBuilder.getBoolConstant(isOnDiagonal);
1860                             break;
1861                         default:
1862                             UNREACHABLE();
1863                     }
1864                 }
1865 
1866                 componentIds.push_back(componentId);
1867             }
1868 
1869             // Create the column vector.
1870             columnIds.push_back(mBuilder.getNewId(decorations));
1871             spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
1872                                            columnIds.back(), componentIds);
1873         }
1874     }
1875 
1876     const spirv::IdRef result = mBuilder.getNewId(decorations);
1877     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
1878                                    columnIds);
1879     return result;
1880 }
1881 
loadAllParams(TIntermOperator * node,size_t skipCount,spirv::IdRefList * paramTypeIds)1882 spirv::IdRefList OutputSPIRVTraverser::loadAllParams(TIntermOperator *node,
1883                                                      size_t skipCount,
1884                                                      spirv::IdRefList *paramTypeIds)
1885 {
1886     const size_t parameterCount = node->getChildCount();
1887     spirv::IdRefList parameters;
1888 
1889     for (size_t paramIndex = 0; paramIndex + skipCount < parameterCount; ++paramIndex)
1890     {
1891         // Take each parameter that is visited and evaluate it as rvalue
1892         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
1893 
1894         spirv::IdRef paramTypeId;
1895         const spirv::IdRef paramValue = accessChainLoad(
1896             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), &paramTypeId);
1897 
1898         parameters.push_back(paramValue);
1899         if (paramTypeIds)
1900         {
1901             paramTypeIds->push_back(paramTypeId);
1902         }
1903     }
1904 
1905     return parameters;
1906 }
1907 
extractComponents(TIntermAggregate * node,size_t componentCount,const spirv::IdRefList & parameters,spirv::IdRefList * extractedComponentsOut)1908 void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
1909                                              size_t componentCount,
1910                                              const spirv::IdRefList &parameters,
1911                                              spirv::IdRefList *extractedComponentsOut)
1912 {
1913     // A helper function that takes the list of parameters passed to a constructor (which may have
1914     // more components than necessary) and extracts the first componentCount components.
1915     const TIntermSequence &arguments = *node->getSequence();
1916 
1917     const SpirvDecorations decorations = mBuilder.getDecorations(node->getType());
1918     const TType &expectedType          = node->getType();
1919 
1920     ASSERT(arguments.size() == parameters.size());
1921 
1922     for (size_t argumentIndex = 0;
1923          argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
1924          ++argumentIndex)
1925     {
1926         TIntermNode *argument          = arguments[argumentIndex];
1927         const TType &argumentType      = argument->getAsTyped()->getType();
1928         const spirv::IdRef parameterId = parameters[argumentIndex];
1929 
1930         if (argumentType.isScalar())
1931         {
1932             // For scalar parameters, there's nothing to do other than a potential cast.
1933             const spirv::IdRef castParameterId =
1934                 argument->getAsConstantUnion()
1935                     ? parameterId
1936                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1937             extractedComponentsOut->push_back(castParameterId);
1938             continue;
1939         }
1940         if (argumentType.isVector())
1941         {
1942             TType componentType(argumentType);
1943             componentType.toComponentType();
1944             componentType.setBasicType(expectedType.getBasicType());
1945             const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1946 
1947             // Cast the whole vector parameter in one go.
1948             const spirv::IdRef castParameterId =
1949                 argument->getAsConstantUnion()
1950                     ? parameterId
1951                     : castBasicType(parameterId, argumentType, expectedType, nullptr);
1952 
1953             // For vector parameters, take components out of the vector one by one.
1954             for (uint8_t componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
1955                                              extractedComponentsOut->size() < componentCount;
1956                  ++componentIndex)
1957             {
1958                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1959                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
1960                                              componentTypeId, componentId, castParameterId,
1961                                              {spirv::LiteralInteger(componentIndex)});
1962 
1963                 extractedComponentsOut->push_back(componentId);
1964             }
1965             continue;
1966         }
1967 
1968         ASSERT(argumentType.isMatrix());
1969 
1970         TType componentType(argumentType);
1971         componentType.toComponentType();
1972         const spirv::IdRef componentTypeId = mBuilder.getTypeData(componentType, {}).id;
1973 
1974         // For matrix parameters, take components out of the matrix one by one in column-major
1975         // order.  No cast is done here; it would only be required for vector constructors with
1976         // matrix parameters, in which case the resulting vector is cast in the end.
1977         for (uint8_t columnIndex = 0; columnIndex < argumentType.getCols() &&
1978                                       extractedComponentsOut->size() < componentCount;
1979              ++columnIndex)
1980         {
1981             for (uint8_t componentIndex = 0; componentIndex < argumentType.getRows() &&
1982                                              extractedComponentsOut->size() < componentCount;
1983                  ++componentIndex)
1984             {
1985                 const spirv::IdRef componentId = mBuilder.getNewId(decorations);
1986                 spirv::WriteCompositeExtract(
1987                     mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId, componentId,
1988                     parameterId,
1989                     {spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
1990 
1991                 extractedComponentsOut->push_back(componentId);
1992             }
1993         }
1994     }
1995 }
1996 
startShortCircuit(TIntermBinary * node)1997 void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
1998 {
1999     // Emulate && and || as such:
2000     //
2001     //   || => if (!left) result = right
2002     //   && => if ( left) result = right
2003     //
2004     // When this function is called, |left| has already been visited, so it creates the appropriate
2005     // |if| construct in preparation for visiting |right|.
2006 
2007     // Load |left| and replace the access chain with an rvalue that's the result.
2008     spirv::IdRef typeId;
2009     const spirv::IdRef left =
2010         accessChainLoad(&mNodeData.back(), node->getLeft()->getType(), &typeId);
2011     nodeDataInitRValue(&mNodeData.back(), left, typeId);
2012 
2013     // Keep the id of the block |left| was evaluated in.
2014     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
2015 
2016     // Two blocks necessary, one for the |if| block, and one for the merge block.
2017     mBuilder.startConditional(2, false, false);
2018 
2019     // Generate the branch instructions.
2020     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
2021 
2022     const spirv::IdRef mergeBlock = conditional->blockIds.back();
2023     const spirv::IdRef ifBlock    = conditional->blockIds.front();
2024     const spirv::IdRef trueBlock  = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
2025     const spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
2026 
2027     // Note that no logical not is necessary.  For ||, the branch will target the merge block in the
2028     // true case.
2029     mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
2030 }
2031 
endShortCircuit(TIntermBinary * node,spirv::IdRef * typeId)2032 spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
2033 {
2034     // Load the right hand side.
2035     const spirv::IdRef right =
2036         accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
2037     mNodeData.pop_back();
2038 
2039     // Get the id of the block |right| is evaluated in.
2040     const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
2041 
2042     // And the cached id of the block |left| is evaluated in.
2043     ASSERT(mNodeData.back().idList.size() == 1);
2044     const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
2045     mNodeData.back().idList.clear();
2046 
2047     // Move on to the merge block.
2048     mBuilder.writeBranchConditionalBlockEnd();
2049 
2050     // Pop from the conditional stack.
2051     mBuilder.endConditional();
2052 
2053     // Get the previously loaded result of the left hand side.
2054     *typeId                 = mNodeData.back().accessChain.baseTypeId;
2055     const spirv::IdRef left = mNodeData.back().baseId;
2056 
2057     // Create an OpPhi instruction that selects either the |left| or |right| based on which block
2058     // was traversed.
2059     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2060 
2061     spirv::WritePhi(
2062         mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
2063         {spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
2064 
2065     return result;
2066 }
2067 
createFunctionCall(TIntermAggregate * node,spirv::IdRef resultTypeId)2068 spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
2069                                                       spirv::IdRef resultTypeId)
2070 {
2071     const TFunction *function = node->getFunction();
2072     ASSERT(function);
2073 
2074     ASSERT(mFunctionIdMap.count(function) > 0);
2075     const spirv::IdRef functionId = mFunctionIdMap[function].functionId;
2076 
2077     // Get the list of parameters passed to the function.  The function parameters can only be
2078     // memory variables, or if the function argument is |const|, an rvalue.
2079     //
2080     // For opaque uniforms, pass it directly as lvalue.
2081     //
2082     // For in variables:
2083     //
2084     // - If the parameter is const, pass it directly as rvalue, otherwise
2085     // - Write it to a temp variable first and pass that.
2086     //
2087     // For out variables:
2088     //
2089     // - Pass a temporary variable.  After the function call, copy that variable to the parameter.
2090     //
2091     // For inout variables:
2092     //
2093     // - Write the parameter to a temp variable and pass that.  After the function call, copy that
2094     //   variable back to the parameter.
2095     //
2096     // Note that in GLSL, in parameters are considered "copied" to the function.  In SPIR-V, every
2097     // parameter is implicitly inout.  If a function takes an in parameter and modifies it, the
2098     // caller has to ensure that it calls the function with a copy.  Currently, the functions don't
2099     // track whether an in parameter is modified, so we conservatively assume it is.  Even for out
2100     // and inout parameters, GLSL expects each function to operate on their local copy until the end
2101     // of the function; this has observable side effects if the out variable aliases another
2102     // variable the function has access to (another out variable, a global variable etc).
2103     //
2104     const size_t parameterCount = node->getChildCount();
2105     spirv::IdRefList parameters;
2106     spirv::IdRefList tempVarIds(parameterCount);
2107     spirv::IdRefList tempVarTypeIds(parameterCount);
2108 
2109     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2110     {
2111         const TType &paramType           = function->getParam(paramIndex)->getType();
2112         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2113         const TQualifier &paramQualifier = paramType.getQualifier();
2114         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2115 
2116         spirv::IdRef paramValue;
2117 
2118         if (paramQualifier == EvqParamConst)
2119         {
2120             // |const| parameters are passed as rvalue.
2121             paramValue = accessChainLoad(&param, argType, nullptr);
2122         }
2123         else if (IsOpaqueType(paramType.getBasicType()))
2124         {
2125             // Opaque uniforms are passed by pointer.
2126             paramValue = accessChainCollapse(&param);
2127         }
2128         else
2129         {
2130             ASSERT(paramQualifier == EvqParamIn || paramQualifier == EvqParamOut ||
2131                    paramQualifier == EvqParamInOut);
2132 
2133             // Need to create a temp variable and pass that.
2134             tempVarTypeIds[paramIndex] = mBuilder.getTypeData(paramType, {}).id;
2135             tempVarIds[paramIndex]     = mBuilder.declareVariable(
2136                 tempVarTypeIds[paramIndex], spv::StorageClassFunction,
2137                 mBuilder.getDecorations(argType), nullptr, "param", nullptr);
2138 
2139             // If it's an in or inout parameter, the temp variable needs to be initialized with the
2140             // value of the parameter first.
2141             if (paramQualifier == EvqParamIn || paramQualifier == EvqParamInOut)
2142             {
2143                 paramValue = accessChainLoad(&param, argType, nullptr);
2144                 spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), tempVarIds[paramIndex],
2145                                   paramValue, nullptr);
2146             }
2147 
2148             paramValue = tempVarIds[paramIndex];
2149         }
2150 
2151         parameters.push_back(paramValue);
2152     }
2153 
2154     // Make the actual function call.
2155     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
2156     spirv::WriteFunctionCall(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
2157                              functionId, parameters);
2158 
2159     // Copy from the out and inout temp variables back to the original parameters.
2160     for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
2161     {
2162         if (!tempVarIds[paramIndex].valid())
2163         {
2164             continue;
2165         }
2166 
2167         const TType &paramType           = function->getParam(paramIndex)->getType();
2168         const TType &argType             = node->getChildNode(paramIndex)->getAsTyped()->getType();
2169         const TQualifier &paramQualifier = paramType.getQualifier();
2170         NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
2171 
2172         if (paramQualifier == EvqParamIn)
2173         {
2174             continue;
2175         }
2176 
2177         // Copy from the temp variable to the parameter.
2178         NodeData tempVarData;
2179         nodeDataInitLValue(&tempVarData, tempVarIds[paramIndex], tempVarTypeIds[paramIndex],
2180                            spv::StorageClassFunction, {});
2181         const spirv::IdRef tempVarValue = accessChainLoad(&tempVarData, argType, nullptr);
2182         accessChainStore(&param, tempVarValue, function->getParam(paramIndex)->getType());
2183     }
2184 
2185     return result;
2186 }
2187 
visitArrayLength(TIntermUnary * node)2188 void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
2189 {
2190     // .length() on sized arrays is already constant folded, so this operation only applies to
2191     // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
2192     // index of last_member, so those need to be extracted from the access chain.  Additionally,
2193     // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
2194     // necessary.
2195 
2196     // Inspect the children.  There are two possibilities:
2197     //
2198     // - last_member.length(): In this case, the id of the nameless ssbo is used.
2199     // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
2200     // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
2201     //
2202     // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
2203     // which is not useful.
2204 
2205     spirv::IdRef accessChainId;
2206     spirv::LiteralInteger fieldIndex;
2207 
2208     if (node->getOperand()->getAsSymbolNode())
2209     {
2210         // If the operand is a symbol referencing the last member of a nameless interface block,
2211         // visit the symbol to get the id of the interface block.
2212         node->getOperand()->getAsSymbolNode()->traverse(this);
2213 
2214         // The access chain must only include the base id + one literal field index.
2215         ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
2216 
2217         accessChainId = mNodeData.back().baseId;
2218         fieldIndex    = mNodeData.back().idList.back().literal;
2219     }
2220     else
2221     {
2222         // Otherwise make sure not to traverse the field index selection node so that the access
2223         // chain would not include it.
2224         TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
2225         ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
2226 
2227         TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
2228         TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
2229         ASSERT(indexNode);
2230 
2231         // Visit the expression.
2232         interfaceBlockExpression->traverse(this);
2233 
2234         accessChainId = accessChainCollapse(&mNodeData.back());
2235         fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
2236     }
2237 
2238     // Get the int and uint type ids.
2239     const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
2240     const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
2241 
2242     // Generate the instruction.
2243     const spirv::IdRef resultId = mBuilder.getNewId({});
2244     spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
2245                             accessChainId, fieldIndex);
2246 
2247     // Cast to int.
2248     const spirv::IdRef castResultId = mBuilder.getNewId({});
2249     spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
2250 
2251     // Replace the access chain with an rvalue that's the result.
2252     nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
2253 }
2254 
IsShortCircuitNeeded(TIntermOperator * node)2255 bool IsShortCircuitNeeded(TIntermOperator *node)
2256 {
2257     TOperator op = node->getOp();
2258 
2259     // Short circuit is only necessary for && and ||.
2260     if (op != EOpLogicalAnd && op != EOpLogicalOr)
2261     {
2262         return false;
2263     }
2264 
2265     ASSERT(node->getChildCount() == 2);
2266 
2267     // If the right hand side does not have side effects, short-circuiting is unnecessary.
2268     // TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
2269     // complexity of the right hand side expression.  We could potentially only allow
2270     // OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
2271     // expressions be placed inside an if block.  http://anglebug.com/40096715
2272     return node->getChildNode(1)->getAsTyped()->hasSideEffects();
2273 }
2274 
2275 using WriteUnaryOp      = void (*)(spirv::Blob *blob,
2276                               spirv::IdResultType idResultType,
2277                               spirv::IdResult idResult,
2278                               spirv::IdRef operand);
2279 using WriteBinaryOp     = void (*)(spirv::Blob *blob,
2280                                spirv::IdResultType idResultType,
2281                                spirv::IdResult idResult,
2282                                spirv::IdRef operand1,
2283                                spirv::IdRef operand2);
2284 using WriteTernaryOp    = void (*)(spirv::Blob *blob,
2285                                 spirv::IdResultType idResultType,
2286                                 spirv::IdResult idResult,
2287                                 spirv::IdRef operand1,
2288                                 spirv::IdRef operand2,
2289                                 spirv::IdRef operand3);
2290 using WriteQuaternaryOp = void (*)(spirv::Blob *blob,
2291                                    spirv::IdResultType idResultType,
2292                                    spirv::IdResult idResult,
2293                                    spirv::IdRef operand1,
2294                                    spirv::IdRef operand2,
2295                                    spirv::IdRef operand3,
2296                                    spirv::IdRef operand4);
2297 using WriteAtomicOp     = void (*)(spirv::Blob *blob,
2298                                spirv::IdResultType idResultType,
2299                                spirv::IdResult idResult,
2300                                spirv::IdRef pointer,
2301                                spirv::IdScope scope,
2302                                spirv::IdMemorySemantics semantics,
2303                                spirv::IdRef value);
2304 
visitOperator(TIntermOperator * node,spirv::IdRef resultTypeId)2305 spirv::IdRef OutputSPIRVTraverser::visitOperator(TIntermOperator *node, spirv::IdRef resultTypeId)
2306 {
2307     // Handle special groups.
2308     const TOperator op = node->getOp();
2309     if (op == EOpEqual || op == EOpNotEqual)
2310     {
2311         return createCompare(node, resultTypeId);
2312     }
2313     if (BuiltInGroup::IsAtomicMemory(op) || BuiltInGroup::IsImageAtomic(op))
2314     {
2315         return createAtomicBuiltIn(node, resultTypeId);
2316     }
2317     if (BuiltInGroup::IsImage(op) || BuiltInGroup::IsTexture(op))
2318     {
2319         return createImageTextureBuiltIn(node, resultTypeId);
2320     }
2321     if (op == EOpSubpassLoad)
2322     {
2323         return createSubpassLoadBuiltIn(node, resultTypeId);
2324     }
2325     if (BuiltInGroup::IsInterpolationFS(op))
2326     {
2327         return createInterpolate(node, resultTypeId);
2328     }
2329 
2330     const size_t childCount   = node->getChildCount();
2331     TIntermTyped *firstChild  = node->getChildNode(0)->getAsTyped();
2332     TIntermTyped *secondChild = childCount > 1 ? node->getChildNode(1)->getAsTyped() : nullptr;
2333 
2334     const TType &firstOperandType = firstChild->getType();
2335     const TBasicType basicType    = firstOperandType.getBasicType();
2336     const bool isFloat            = basicType == EbtFloat;
2337     const bool isUnsigned         = basicType == EbtUInt;
2338     const bool isBool             = basicType == EbtBool;
2339     // Whether this is a pre/post increment/decrement operator.
2340     bool isIncrementOrDecrement = false;
2341     // Whether the operation needs to be applied column by column.
2342     bool operateOnColumns =
2343         childCount == 2 && (firstChild->getType().isMatrix() || secondChild->getType().isMatrix());
2344     // Whether the operands need to be swapped in the (binary) instruction
2345     bool binarySwapOperands = false;
2346     // Whether the instruction is matrix/scalar.  This is implemented with matrix*(1/scalar).
2347     bool binaryInvertSecondParameter = false;
2348     // Whether the scalar operand needs to be extended to match the other operand which is a vector
2349     // (in a binary or extended operation).
2350     bool extendScalarToVector = true;
2351     // Some built-ins have out parameters at the end of the list of parameters.
2352     size_t lvalueCount = 0;
2353 
2354     WriteUnaryOp writeUnaryOp           = nullptr;
2355     WriteBinaryOp writeBinaryOp         = nullptr;
2356     WriteTernaryOp writeTernaryOp       = nullptr;
2357     WriteQuaternaryOp writeQuaternaryOp = nullptr;
2358 
2359     // Some operators are implemented with an extended instruction.
2360     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
2361 
2362     switch (op)
2363     {
2364         case EOpNegative:
2365             operateOnColumns = firstOperandType.isMatrix();
2366             if (isFloat)
2367                 writeUnaryOp = spirv::WriteFNegate;
2368             else
2369                 writeUnaryOp = spirv::WriteSNegate;
2370             break;
2371         case EOpPositive:
2372             // This is a noop.
2373             return accessChainLoad(&mNodeData.back(), firstOperandType, nullptr);
2374 
2375         case EOpLogicalNot:
2376         case EOpNotComponentWise:
2377             writeUnaryOp = spirv::WriteLogicalNot;
2378             break;
2379         case EOpBitwiseNot:
2380             writeUnaryOp = spirv::WriteNot;
2381             break;
2382 
2383         case EOpPostIncrement:
2384         case EOpPreIncrement:
2385             isIncrementOrDecrement = true;
2386             operateOnColumns       = firstOperandType.isMatrix();
2387             if (isFloat)
2388                 writeBinaryOp = spirv::WriteFAdd;
2389             else
2390                 writeBinaryOp = spirv::WriteIAdd;
2391             break;
2392         case EOpPostDecrement:
2393         case EOpPreDecrement:
2394             isIncrementOrDecrement = true;
2395             operateOnColumns       = firstOperandType.isMatrix();
2396             if (isFloat)
2397                 writeBinaryOp = spirv::WriteFSub;
2398             else
2399                 writeBinaryOp = spirv::WriteISub;
2400             break;
2401 
2402         case EOpAdd:
2403         case EOpAddAssign:
2404             if (isFloat)
2405                 writeBinaryOp = spirv::WriteFAdd;
2406             else
2407                 writeBinaryOp = spirv::WriteIAdd;
2408             break;
2409         case EOpSub:
2410         case EOpSubAssign:
2411             if (isFloat)
2412                 writeBinaryOp = spirv::WriteFSub;
2413             else
2414                 writeBinaryOp = spirv::WriteISub;
2415             break;
2416         case EOpMul:
2417         case EOpMulAssign:
2418         case EOpMatrixCompMult:
2419             if (isFloat)
2420                 writeBinaryOp = spirv::WriteFMul;
2421             else
2422                 writeBinaryOp = spirv::WriteIMul;
2423             break;
2424         case EOpDiv:
2425         case EOpDivAssign:
2426             if (isFloat)
2427             {
2428                 if (firstOperandType.isMatrix() && secondChild->getType().isScalar())
2429                 {
2430                     writeBinaryOp               = spirv::WriteMatrixTimesScalar;
2431                     binaryInvertSecondParameter = true;
2432                     operateOnColumns            = false;
2433                     extendScalarToVector        = false;
2434                 }
2435                 else
2436                 {
2437                     writeBinaryOp = spirv::WriteFDiv;
2438                 }
2439             }
2440             else if (isUnsigned)
2441                 writeBinaryOp = spirv::WriteUDiv;
2442             else
2443                 writeBinaryOp = spirv::WriteSDiv;
2444             break;
2445         case EOpIMod:
2446         case EOpIModAssign:
2447             if (isFloat)
2448                 writeBinaryOp = spirv::WriteFMod;
2449             else if (isUnsigned)
2450                 writeBinaryOp = spirv::WriteUMod;
2451             else
2452                 writeBinaryOp = spirv::WriteSMod;
2453             break;
2454 
2455         case EOpEqualComponentWise:
2456             if (isFloat)
2457                 writeBinaryOp = spirv::WriteFOrdEqual;
2458             else if (isBool)
2459                 writeBinaryOp = spirv::WriteLogicalEqual;
2460             else
2461                 writeBinaryOp = spirv::WriteIEqual;
2462             break;
2463         case EOpNotEqualComponentWise:
2464             if (isFloat)
2465                 writeBinaryOp = spirv::WriteFUnordNotEqual;
2466             else if (isBool)
2467                 writeBinaryOp = spirv::WriteLogicalNotEqual;
2468             else
2469                 writeBinaryOp = spirv::WriteINotEqual;
2470             break;
2471         case EOpLessThan:
2472         case EOpLessThanComponentWise:
2473             if (isFloat)
2474                 writeBinaryOp = spirv::WriteFOrdLessThan;
2475             else if (isUnsigned)
2476                 writeBinaryOp = spirv::WriteULessThan;
2477             else
2478                 writeBinaryOp = spirv::WriteSLessThan;
2479             break;
2480         case EOpGreaterThan:
2481         case EOpGreaterThanComponentWise:
2482             if (isFloat)
2483                 writeBinaryOp = spirv::WriteFOrdGreaterThan;
2484             else if (isUnsigned)
2485                 writeBinaryOp = spirv::WriteUGreaterThan;
2486             else
2487                 writeBinaryOp = spirv::WriteSGreaterThan;
2488             break;
2489         case EOpLessThanEqual:
2490         case EOpLessThanEqualComponentWise:
2491             if (isFloat)
2492                 writeBinaryOp = spirv::WriteFOrdLessThanEqual;
2493             else if (isUnsigned)
2494                 writeBinaryOp = spirv::WriteULessThanEqual;
2495             else
2496                 writeBinaryOp = spirv::WriteSLessThanEqual;
2497             break;
2498         case EOpGreaterThanEqual:
2499         case EOpGreaterThanEqualComponentWise:
2500             if (isFloat)
2501                 writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
2502             else if (isUnsigned)
2503                 writeBinaryOp = spirv::WriteUGreaterThanEqual;
2504             else
2505                 writeBinaryOp = spirv::WriteSGreaterThanEqual;
2506             break;
2507 
2508         case EOpVectorTimesScalar:
2509         case EOpVectorTimesScalarAssign:
2510             if (isFloat)
2511             {
2512                 writeBinaryOp        = spirv::WriteVectorTimesScalar;
2513                 binarySwapOperands   = node->getChildNode(1)->getAsTyped()->getType().isVector();
2514                 extendScalarToVector = false;
2515             }
2516             else
2517                 writeBinaryOp = spirv::WriteIMul;
2518             break;
2519         case EOpVectorTimesMatrix:
2520         case EOpVectorTimesMatrixAssign:
2521             writeBinaryOp    = spirv::WriteVectorTimesMatrix;
2522             operateOnColumns = false;
2523             break;
2524         case EOpMatrixTimesVector:
2525             writeBinaryOp    = spirv::WriteMatrixTimesVector;
2526             operateOnColumns = false;
2527             break;
2528         case EOpMatrixTimesScalar:
2529         case EOpMatrixTimesScalarAssign:
2530             writeBinaryOp        = spirv::WriteMatrixTimesScalar;
2531             binarySwapOperands   = secondChild->getType().isMatrix();
2532             operateOnColumns     = false;
2533             extendScalarToVector = false;
2534             break;
2535         case EOpMatrixTimesMatrix:
2536         case EOpMatrixTimesMatrixAssign:
2537             writeBinaryOp    = spirv::WriteMatrixTimesMatrix;
2538             operateOnColumns = false;
2539             break;
2540 
2541         case EOpLogicalOr:
2542             ASSERT(!IsShortCircuitNeeded(node));
2543             extendScalarToVector = false;
2544             writeBinaryOp        = spirv::WriteLogicalOr;
2545             break;
2546         case EOpLogicalXor:
2547             extendScalarToVector = false;
2548             writeBinaryOp        = spirv::WriteLogicalNotEqual;
2549             break;
2550         case EOpLogicalAnd:
2551             ASSERT(!IsShortCircuitNeeded(node));
2552             extendScalarToVector = false;
2553             writeBinaryOp        = spirv::WriteLogicalAnd;
2554             break;
2555 
2556         case EOpBitShiftLeft:
2557         case EOpBitShiftLeftAssign:
2558             writeBinaryOp = spirv::WriteShiftLeftLogical;
2559             break;
2560         case EOpBitShiftRight:
2561         case EOpBitShiftRightAssign:
2562             if (isUnsigned)
2563                 writeBinaryOp = spirv::WriteShiftRightLogical;
2564             else
2565                 writeBinaryOp = spirv::WriteShiftRightArithmetic;
2566             break;
2567         case EOpBitwiseAnd:
2568         case EOpBitwiseAndAssign:
2569             writeBinaryOp = spirv::WriteBitwiseAnd;
2570             break;
2571         case EOpBitwiseXor:
2572         case EOpBitwiseXorAssign:
2573             writeBinaryOp = spirv::WriteBitwiseXor;
2574             break;
2575         case EOpBitwiseOr:
2576         case EOpBitwiseOrAssign:
2577             writeBinaryOp = spirv::WriteBitwiseOr;
2578             break;
2579 
2580         case EOpRadians:
2581             extendedInst = spv::GLSLstd450Radians;
2582             break;
2583         case EOpDegrees:
2584             extendedInst = spv::GLSLstd450Degrees;
2585             break;
2586         case EOpSin:
2587             extendedInst = spv::GLSLstd450Sin;
2588             break;
2589         case EOpCos:
2590             extendedInst = spv::GLSLstd450Cos;
2591             break;
2592         case EOpTan:
2593             extendedInst = spv::GLSLstd450Tan;
2594             break;
2595         case EOpAsin:
2596             extendedInst = spv::GLSLstd450Asin;
2597             break;
2598         case EOpAcos:
2599             extendedInst = spv::GLSLstd450Acos;
2600             break;
2601         case EOpAtan:
2602             extendedInst = childCount == 1 ? spv::GLSLstd450Atan : spv::GLSLstd450Atan2;
2603             break;
2604         case EOpSinh:
2605             extendedInst = spv::GLSLstd450Sinh;
2606             break;
2607         case EOpCosh:
2608             extendedInst = spv::GLSLstd450Cosh;
2609             break;
2610         case EOpTanh:
2611             extendedInst = spv::GLSLstd450Tanh;
2612             break;
2613         case EOpAsinh:
2614             extendedInst = spv::GLSLstd450Asinh;
2615             break;
2616         case EOpAcosh:
2617             extendedInst = spv::GLSLstd450Acosh;
2618             break;
2619         case EOpAtanh:
2620             extendedInst = spv::GLSLstd450Atanh;
2621             break;
2622 
2623         case EOpPow:
2624             extendedInst = spv::GLSLstd450Pow;
2625             break;
2626         case EOpExp:
2627             extendedInst = spv::GLSLstd450Exp;
2628             break;
2629         case EOpLog:
2630             extendedInst = spv::GLSLstd450Log;
2631             break;
2632         case EOpExp2:
2633             extendedInst = spv::GLSLstd450Exp2;
2634             break;
2635         case EOpLog2:
2636             extendedInst = spv::GLSLstd450Log2;
2637             break;
2638         case EOpSqrt:
2639             extendedInst = spv::GLSLstd450Sqrt;
2640             break;
2641         case EOpInversesqrt:
2642             extendedInst = spv::GLSLstd450InverseSqrt;
2643             break;
2644 
2645         case EOpAbs:
2646             if (isFloat)
2647                 extendedInst = spv::GLSLstd450FAbs;
2648             else
2649                 extendedInst = spv::GLSLstd450SAbs;
2650             break;
2651         case EOpSign:
2652             if (isFloat)
2653                 extendedInst = spv::GLSLstd450FSign;
2654             else
2655                 extendedInst = spv::GLSLstd450SSign;
2656             break;
2657         case EOpFloor:
2658             extendedInst = spv::GLSLstd450Floor;
2659             break;
2660         case EOpTrunc:
2661             extendedInst = spv::GLSLstd450Trunc;
2662             break;
2663         case EOpRound:
2664             extendedInst = spv::GLSLstd450Round;
2665             break;
2666         case EOpRoundEven:
2667             extendedInst = spv::GLSLstd450RoundEven;
2668             break;
2669         case EOpCeil:
2670             extendedInst = spv::GLSLstd450Ceil;
2671             break;
2672         case EOpFract:
2673             extendedInst = spv::GLSLstd450Fract;
2674             break;
2675         case EOpMod:
2676             if (isFloat)
2677                 writeBinaryOp = spirv::WriteFMod;
2678             else if (isUnsigned)
2679                 writeBinaryOp = spirv::WriteUMod;
2680             else
2681                 writeBinaryOp = spirv::WriteSMod;
2682             break;
2683         case EOpMin:
2684             if (isFloat)
2685                 extendedInst = spv::GLSLstd450FMin;
2686             else if (isUnsigned)
2687                 extendedInst = spv::GLSLstd450UMin;
2688             else
2689                 extendedInst = spv::GLSLstd450SMin;
2690             break;
2691         case EOpMax:
2692             if (isFloat)
2693                 extendedInst = spv::GLSLstd450FMax;
2694             else if (isUnsigned)
2695                 extendedInst = spv::GLSLstd450UMax;
2696             else
2697                 extendedInst = spv::GLSLstd450SMax;
2698             break;
2699         case EOpClamp:
2700             if (isFloat)
2701                 extendedInst = spv::GLSLstd450FClamp;
2702             else if (isUnsigned)
2703                 extendedInst = spv::GLSLstd450UClamp;
2704             else
2705                 extendedInst = spv::GLSLstd450SClamp;
2706             break;
2707         case EOpMix:
2708             if (node->getChildNode(childCount - 1)->getAsTyped()->getType().getBasicType() ==
2709                 EbtBool)
2710             {
2711                 writeTernaryOp = spirv::WriteSelect;
2712             }
2713             else
2714             {
2715                 ASSERT(isFloat);
2716                 extendedInst = spv::GLSLstd450FMix;
2717             }
2718             break;
2719         case EOpStep:
2720             extendedInst = spv::GLSLstd450Step;
2721             break;
2722         case EOpSmoothstep:
2723             extendedInst = spv::GLSLstd450SmoothStep;
2724             break;
2725         case EOpModf:
2726             extendedInst = spv::GLSLstd450ModfStruct;
2727             lvalueCount  = 1;
2728             break;
2729         case EOpIsnan:
2730             writeUnaryOp = spirv::WriteIsNan;
2731             break;
2732         case EOpIsinf:
2733             writeUnaryOp = spirv::WriteIsInf;
2734             break;
2735         case EOpFloatBitsToInt:
2736         case EOpFloatBitsToUint:
2737         case EOpIntBitsToFloat:
2738         case EOpUintBitsToFloat:
2739             writeUnaryOp = spirv::WriteBitcast;
2740             break;
2741         case EOpFma:
2742             extendedInst = spv::GLSLstd450Fma;
2743             break;
2744         case EOpFrexp:
2745             extendedInst = spv::GLSLstd450FrexpStruct;
2746             lvalueCount  = 1;
2747             break;
2748         case EOpLdexp:
2749             extendedInst = spv::GLSLstd450Ldexp;
2750             break;
2751         case EOpPackSnorm2x16:
2752             extendedInst = spv::GLSLstd450PackSnorm2x16;
2753             break;
2754         case EOpPackUnorm2x16:
2755             extendedInst = spv::GLSLstd450PackUnorm2x16;
2756             break;
2757         case EOpPackHalf2x16:
2758             extendedInst = spv::GLSLstd450PackHalf2x16;
2759             break;
2760         case EOpUnpackSnorm2x16:
2761             extendedInst         = spv::GLSLstd450UnpackSnorm2x16;
2762             extendScalarToVector = false;
2763             break;
2764         case EOpUnpackUnorm2x16:
2765             extendedInst         = spv::GLSLstd450UnpackUnorm2x16;
2766             extendScalarToVector = false;
2767             break;
2768         case EOpUnpackHalf2x16:
2769             extendedInst         = spv::GLSLstd450UnpackHalf2x16;
2770             extendScalarToVector = false;
2771             break;
2772         case EOpPackUnorm4x8:
2773             extendedInst = spv::GLSLstd450PackUnorm4x8;
2774             break;
2775         case EOpPackSnorm4x8:
2776             extendedInst = spv::GLSLstd450PackSnorm4x8;
2777             break;
2778         case EOpUnpackUnorm4x8:
2779             extendedInst         = spv::GLSLstd450UnpackUnorm4x8;
2780             extendScalarToVector = false;
2781             break;
2782         case EOpUnpackSnorm4x8:
2783             extendedInst         = spv::GLSLstd450UnpackSnorm4x8;
2784             extendScalarToVector = false;
2785             break;
2786 
2787         case EOpLength:
2788             extendedInst = spv::GLSLstd450Length;
2789             break;
2790         case EOpDistance:
2791             extendedInst = spv::GLSLstd450Distance;
2792             break;
2793         case EOpDot:
2794             // Use normal multiplication for scalars.
2795             if (firstOperandType.isScalar())
2796             {
2797                 if (isFloat)
2798                     writeBinaryOp = spirv::WriteFMul;
2799                 else
2800                     writeBinaryOp = spirv::WriteIMul;
2801             }
2802             else
2803             {
2804                 writeBinaryOp = spirv::WriteDot;
2805             }
2806             break;
2807         case EOpCross:
2808             extendedInst = spv::GLSLstd450Cross;
2809             break;
2810         case EOpNormalize:
2811             extendedInst = spv::GLSLstd450Normalize;
2812             break;
2813         case EOpFaceforward:
2814             extendedInst = spv::GLSLstd450FaceForward;
2815             break;
2816         case EOpReflect:
2817             extendedInst = spv::GLSLstd450Reflect;
2818             break;
2819         case EOpRefract:
2820             extendedInst         = spv::GLSLstd450Refract;
2821             extendScalarToVector = false;
2822             break;
2823 
2824         case EOpOuterProduct:
2825             writeBinaryOp = spirv::WriteOuterProduct;
2826             break;
2827         case EOpTranspose:
2828             writeUnaryOp = spirv::WriteTranspose;
2829             break;
2830         case EOpDeterminant:
2831             extendedInst = spv::GLSLstd450Determinant;
2832             break;
2833         case EOpInverse:
2834             extendedInst = spv::GLSLstd450MatrixInverse;
2835             break;
2836 
2837         case EOpAny:
2838             writeUnaryOp = spirv::WriteAny;
2839             break;
2840         case EOpAll:
2841             writeUnaryOp = spirv::WriteAll;
2842             break;
2843 
2844         case EOpBitfieldExtract:
2845             if (isUnsigned)
2846                 writeTernaryOp = spirv::WriteBitFieldUExtract;
2847             else
2848                 writeTernaryOp = spirv::WriteBitFieldSExtract;
2849             break;
2850         case EOpBitfieldInsert:
2851             writeQuaternaryOp = spirv::WriteBitFieldInsert;
2852             break;
2853         case EOpBitfieldReverse:
2854             writeUnaryOp = spirv::WriteBitReverse;
2855             break;
2856         case EOpBitCount:
2857             writeUnaryOp = spirv::WriteBitCount;
2858             break;
2859         case EOpFindLSB:
2860             extendedInst = spv::GLSLstd450FindILsb;
2861             break;
2862         case EOpFindMSB:
2863             if (isUnsigned)
2864                 extendedInst = spv::GLSLstd450FindUMsb;
2865             else
2866                 extendedInst = spv::GLSLstd450FindSMsb;
2867             break;
2868         case EOpUaddCarry:
2869             writeBinaryOp = spirv::WriteIAddCarry;
2870             lvalueCount   = 1;
2871             break;
2872         case EOpUsubBorrow:
2873             writeBinaryOp = spirv::WriteISubBorrow;
2874             lvalueCount   = 1;
2875             break;
2876         case EOpUmulExtended:
2877             writeBinaryOp = spirv::WriteUMulExtended;
2878             lvalueCount   = 2;
2879             break;
2880         case EOpImulExtended:
2881             writeBinaryOp = spirv::WriteSMulExtended;
2882             lvalueCount   = 2;
2883             break;
2884 
2885         case EOpRgb_2_yuv:
2886         case EOpYuv_2_rgb:
2887             // These built-ins are emulated, and shouldn't be encountered at this point.
2888             UNREACHABLE();
2889             break;
2890 
2891         case EOpDFdx:
2892             writeUnaryOp = spirv::WriteDPdx;
2893             break;
2894         case EOpDFdy:
2895             writeUnaryOp = spirv::WriteDPdy;
2896             break;
2897         case EOpFwidth:
2898             writeUnaryOp = spirv::WriteFwidth;
2899             break;
2900 
2901         default:
2902             UNREACHABLE();
2903     }
2904 
2905     // Load the parameters.
2906     spirv::IdRefList parameterTypeIds;
2907     spirv::IdRefList parameters = loadAllParams(node, lvalueCount, &parameterTypeIds);
2908 
2909     if (isIncrementOrDecrement)
2910     {
2911         // ++ and -- are implemented with binary add and subtract, so add an implicit parameter with
2912         // size vecN(1).
2913         const uint8_t vecSize = firstOperandType.isMatrix() ? firstOperandType.getRows()
2914                                                             : firstOperandType.getNominalSize();
2915         const spirv::IdRef one =
2916             isFloat ? mBuilder.getVecConstant(1, vecSize) : mBuilder.getIvecConstant(1, vecSize);
2917         parameters.push_back(one);
2918     }
2919 
2920     const SpirvDecorations decorations =
2921         mBuilder.getArithmeticDecorations(node->getType(), node->isPrecise(), op);
2922     spirv::IdRef result;
2923     if (node->getType().getBasicType() != EbtVoid)
2924     {
2925         result = mBuilder.getNewId(decorations);
2926     }
2927 
2928     // In the case of modf, frexp, uaddCarry, usubBorrow, umulExtended and imulExtended, the SPIR-V
2929     // result is expected to be a struct instead.
2930     spirv::IdRef builtInResultTypeId = resultTypeId;
2931     spirv::IdRef builtInResult;
2932     if (lvalueCount > 0)
2933     {
2934         builtInResultTypeId = makeBuiltInOutputStructType(node, lvalueCount);
2935         builtInResult       = mBuilder.getNewId({});
2936     }
2937     else
2938     {
2939         builtInResult = result;
2940     }
2941 
2942     if (operateOnColumns)
2943     {
2944         // If negating a matrix, multiplying or comparing them, do that column by column.
2945         // Matrix-scalar operations (add, sub, mod etc) turn the scalar into a vector before
2946         // operating on the column.
2947         spirv::IdRefList columnIds;
2948 
2949         const SpirvDecorations operandDecorations = mBuilder.getDecorations(firstOperandType);
2950 
2951         const TType &matrixType =
2952             firstOperandType.isMatrix() ? firstOperandType : secondChild->getType();
2953 
2954         const spirv::IdRef columnTypeId =
2955             mBuilder.getBasicTypeId(matrixType.getBasicType(), matrixType.getRows());
2956 
2957         if (binarySwapOperands)
2958         {
2959             std::swap(parameters[0], parameters[1]);
2960         }
2961 
2962         if (extendScalarToVector)
2963         {
2964             extendScalarParamsToVector(node, columnTypeId, &parameters);
2965         }
2966 
2967         // Extract and apply the operator to each column.
2968         for (uint8_t columnIndex = 0; columnIndex < matrixType.getCols(); ++columnIndex)
2969         {
2970             spirv::IdRef columnIdA = parameters[0];
2971             if (firstOperandType.isMatrix())
2972             {
2973                 columnIdA = mBuilder.getNewId(operandDecorations);
2974                 spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2975                                              columnIdA, parameters[0],
2976                                              {spirv::LiteralInteger(columnIndex)});
2977             }
2978 
2979             columnIds.push_back(mBuilder.getNewId(decorations));
2980 
2981             if (writeUnaryOp)
2982             {
2983                 writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
2984                              columnIds.back(), columnIdA);
2985             }
2986             else
2987             {
2988                 ASSERT(writeBinaryOp);
2989 
2990                 spirv::IdRef columnIdB = parameters[1];
2991                 if (secondChild != nullptr && secondChild->getType().isMatrix())
2992                 {
2993                     columnIdB = mBuilder.getNewId(operandDecorations);
2994                     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(),
2995                                                  columnTypeId, columnIdB, parameters[1],
2996                                                  {spirv::LiteralInteger(columnIndex)});
2997                 }
2998 
2999                 writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), columnTypeId,
3000                               columnIds.back(), columnIdA, columnIdB);
3001             }
3002         }
3003 
3004         // Construct the result.
3005         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3006                                        builtInResult, columnIds);
3007     }
3008     else if (writeUnaryOp)
3009     {
3010         ASSERT(parameters.size() == 1);
3011         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3012                      parameters[0]);
3013     }
3014     else if (writeBinaryOp)
3015     {
3016         ASSERT(parameters.size() == 2);
3017 
3018         if (extendScalarToVector)
3019         {
3020             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3021         }
3022 
3023         if (binarySwapOperands)
3024         {
3025             std::swap(parameters[0], parameters[1]);
3026         }
3027 
3028         if (binaryInvertSecondParameter)
3029         {
3030             const spirv::IdRef one           = mBuilder.getFloatConstant(1);
3031             const spirv::IdRef invertedParam = mBuilder.getNewId(
3032                 mBuilder.getArithmeticDecorations(secondChild->getType(), node->isPrecise(), op));
3033             spirv::WriteFDiv(mBuilder.getSpirvCurrentFunctionBlock(), parameterTypeIds.back(),
3034                              invertedParam, one, parameters[1]);
3035             parameters[1] = invertedParam;
3036         }
3037 
3038         // Write the operation that combines the left and right values.
3039         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3040                       parameters[0], parameters[1]);
3041     }
3042     else if (writeTernaryOp)
3043     {
3044         ASSERT(parameters.size() == 3);
3045 
3046         // mix(a, b, bool) is the same as bool ? b : a;
3047         if (op == EOpMix)
3048         {
3049             std::swap(parameters[0], parameters[2]);
3050         }
3051 
3052         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId, builtInResult,
3053                        parameters[0], parameters[1], parameters[2]);
3054     }
3055     else if (writeQuaternaryOp)
3056     {
3057         ASSERT(parameters.size() == 4);
3058 
3059         writeQuaternaryOp(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3060                           builtInResult, parameters[0], parameters[1], parameters[2],
3061                           parameters[3]);
3062     }
3063     else
3064     {
3065         // It's an extended instruction.
3066         ASSERT(extendedInst != spv::GLSLstd450Bad);
3067 
3068         if (extendScalarToVector)
3069         {
3070             extendScalarParamsToVector(node, builtInResultTypeId, &parameters);
3071         }
3072 
3073         spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), builtInResultTypeId,
3074                             builtInResult, mBuilder.getExtInstImportIdStd(),
3075                             spirv::LiteralExtInstInteger(extendedInst), parameters);
3076     }
3077 
3078     // If it's an assignment, store the calculated value.
3079     if (IsAssignment(node->getOp()))
3080     {
3081         accessChainStore(&mNodeData[mNodeData.size() - childCount], builtInResult,
3082                          firstOperandType);
3083     }
3084 
3085     // If the operation returns a struct, load the lsb and msb and store them in result/out
3086     // parameters.
3087     if (lvalueCount > 0)
3088     {
3089         storeBuiltInStructOutputInParamsAndReturnValue(node, lvalueCount, builtInResult, result,
3090                                                        resultTypeId);
3091     }
3092 
3093     // For post increment/decrement, return the value of the parameter itself as the result.
3094     if (op == EOpPostIncrement || op == EOpPostDecrement)
3095     {
3096         result = parameters[0];
3097     }
3098 
3099     return result;
3100 }
3101 
createCompare(TIntermOperator * node,spirv::IdRef resultTypeId)3102 spirv::IdRef OutputSPIRVTraverser::createCompare(TIntermOperator *node, spirv::IdRef resultTypeId)
3103 {
3104     const TOperator op       = node->getOp();
3105     TIntermTyped *operand    = node->getChildNode(0)->getAsTyped();
3106     const TType &operandType = operand->getType();
3107 
3108     const SpirvDecorations resultDecorations  = mBuilder.getDecorations(node->getType());
3109     const SpirvDecorations operandDecorations = mBuilder.getDecorations(operandType);
3110 
3111     // Load the left and right values.
3112     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3113     ASSERT(parameters.size() == 2);
3114 
3115     // In GLSL, operators == and != can operate on the following:
3116     //
3117     // - scalars: There's a SPIR-V instruction for this,
3118     // - vectors: The same SPIR-V instruction as scalars is used here, but the result is reduced
3119     //   with OpAll/OpAny for == and != respectively,
3120     // - matrices: Comparison must be done column by column and the result reduced,
3121     // - arrays: Comparison must be done on every array element and the result reduced,
3122     // - structs: Comparison must be done on each field and the result reduced.
3123     //
3124     // For the latter 3 cases, OpCompositeExtract is used to extract scalars and vectors out of the
3125     // more complex type, which is recursively traversed.  The results are accumulated in a list
3126     // that is then reduced 4 by 4 elements until a single boolean is produced.
3127 
3128     spirv::LiteralIntegerList currentAccessChain;
3129     spirv::IdRefList intermediateResults;
3130 
3131     createCompareImpl(op, operandType, resultTypeId, parameters[0], parameters[1],
3132                       operandDecorations, resultDecorations, &currentAccessChain,
3133                       &intermediateResults);
3134 
3135     // Make sure the function correctly pushes and pops access chain indices.
3136     ASSERT(currentAccessChain.empty());
3137 
3138     // Reduce the intermediate results.
3139     ASSERT(!intermediateResults.empty());
3140 
3141     // The following code implements this algorithm, assuming N bools are to be reduced:
3142     //
3143     //    Reduced           To Reduce
3144     //     {b1}           {b2, b3, ..., bN}      Initial state
3145     //                                           Loop
3146     //  {b1, b2, b3, b4}  {b5, b6, ..., bN}        Take up to 3 new bools
3147     //     {r1}           {b5, b6, ..., bN}        Reduce it
3148     //                                             Repeat
3149     //
3150     // In the end, a single value is left.
3151     size_t reducedCount       = 0;
3152     spirv::IdRefList toReduce = {intermediateResults[reducedCount++]};
3153     while (reducedCount < intermediateResults.size())
3154     {
3155         // Take up to 3 new bools.
3156         size_t toTakeCount = std::min<size_t>(3, intermediateResults.size() - reducedCount);
3157         for (size_t i = 0; i < toTakeCount; ++i)
3158         {
3159             toReduce.push_back(intermediateResults[reducedCount++]);
3160         }
3161 
3162         // Reduce them to one bool.
3163         const spirv::IdRef result = reduceBoolVector(op, toReduce, resultTypeId, resultDecorations);
3164 
3165         // Replace the list of bools to reduce with the reduced one.
3166         toReduce.clear();
3167         toReduce.push_back(result);
3168     }
3169 
3170     ASSERT(toReduce.size() == 1 && reducedCount == intermediateResults.size());
3171     return toReduce[0];
3172 }
3173 
createAtomicBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3174 spirv::IdRef OutputSPIRVTraverser::createAtomicBuiltIn(TIntermOperator *node,
3175                                                        spirv::IdRef resultTypeId)
3176 {
3177     const TType &operandType          = node->getChildNode(0)->getAsTyped()->getType();
3178     const TBasicType operandBasicType = operandType.getBasicType();
3179     const bool isImage                = IsImage(operandBasicType);
3180 
3181     // Most atomic instructions are in the form of:
3182     //
3183     //     %result = OpAtomicX %pointer Scope MemorySemantics %value
3184     //
3185     // OpAtomicCompareSwap is exceptionally different (note that compare and value are in different
3186     // order from GLSL):
3187     //
3188     //     %result = OpAtomicCompareExchange %pointer
3189     //                                       Scope MemorySemantics MemorySemantics
3190     //                                       %value %comparator
3191     //
3192     // In all cases, the first parameter is the pointer, and the rest are rvalues.
3193     //
3194     // For images, OpImageTexelPointer is used to form a pointer to the texel on which the atomic
3195     // operation is being performed.
3196     const size_t parameterCount       = node->getChildCount();
3197     size_t imagePointerParameterCount = 0;
3198     spirv::IdRef pointerId;
3199     spirv::IdRefList imagePointerParameters;
3200     spirv::IdRefList parameters;
3201 
3202     if (isImage)
3203     {
3204         // One parameter for coordinates.
3205         ++imagePointerParameterCount;
3206         if (IsImageMS(operandBasicType))
3207         {
3208             // One parameter for samples.
3209             ++imagePointerParameterCount;
3210         }
3211     }
3212 
3213     ASSERT(parameterCount >= 2 + imagePointerParameterCount);
3214 
3215     pointerId = accessChainCollapse(&mNodeData[mNodeData.size() - parameterCount]);
3216     for (size_t paramIndex = 1; paramIndex < parameterCount; ++paramIndex)
3217     {
3218         NodeData &param              = mNodeData[mNodeData.size() - parameterCount + paramIndex];
3219         const spirv::IdRef parameter = accessChainLoad(
3220             &param, node->getChildNode(paramIndex)->getAsTyped()->getType(), nullptr);
3221 
3222         // imageAtomic* built-ins have a few additional parameters right after the image.  These are
3223         // kept separately for use with OpImageTexelPointer.
3224         if (paramIndex <= imagePointerParameterCount)
3225         {
3226             imagePointerParameters.push_back(parameter);
3227         }
3228         else
3229         {
3230             parameters.push_back(parameter);
3231         }
3232     }
3233 
3234     // The scope of the operation is always Device as we don't enable the Vulkan memory model
3235     // extension.
3236     const spirv::IdScope scopeId = mBuilder.getUintConstant(spv::ScopeDevice);
3237 
3238     // The memory semantics is always relaxed as we don't enable the Vulkan memory model extension.
3239     const spirv::IdMemorySemantics semanticsId =
3240         mBuilder.getUintConstant(spv::MemorySemanticsMaskNone);
3241 
3242     WriteAtomicOp writeAtomicOp = nullptr;
3243 
3244     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3245 
3246     // Determine whether the operation is on ints or uints.
3247     const bool isUnsigned = isImage ? IsUIntImage(operandBasicType) : operandBasicType == EbtUInt;
3248 
3249     // For images, convert the pointer to the image to a pointer to a texel in the image.
3250     if (isImage)
3251     {
3252         const spirv::IdRef texelTypePointerId =
3253             mBuilder.getTypePointerId(resultTypeId, spv::StorageClassImage);
3254         const spirv::IdRef texelPointerId = mBuilder.getNewId({});
3255 
3256         const spirv::IdRef coordinate = imagePointerParameters[0];
3257         spirv::IdRef sample = imagePointerParameters.size() > 1 ? imagePointerParameters[1]
3258                                                                 : mBuilder.getUintConstant(0);
3259 
3260         spirv::WriteImageTexelPointer(mBuilder.getSpirvCurrentFunctionBlock(), texelTypePointerId,
3261                                       texelPointerId, pointerId, coordinate, sample);
3262 
3263         pointerId = texelPointerId;
3264     }
3265 
3266     switch (node->getOp())
3267     {
3268         case EOpAtomicAdd:
3269         case EOpImageAtomicAdd:
3270             writeAtomicOp = spirv::WriteAtomicIAdd;
3271             break;
3272         case EOpAtomicMin:
3273         case EOpImageAtomicMin:
3274             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMin : spirv::WriteAtomicSMin;
3275             break;
3276         case EOpAtomicMax:
3277         case EOpImageAtomicMax:
3278             writeAtomicOp = isUnsigned ? spirv::WriteAtomicUMax : spirv::WriteAtomicSMax;
3279             break;
3280         case EOpAtomicAnd:
3281         case EOpImageAtomicAnd:
3282             writeAtomicOp = spirv::WriteAtomicAnd;
3283             break;
3284         case EOpAtomicOr:
3285         case EOpImageAtomicOr:
3286             writeAtomicOp = spirv::WriteAtomicOr;
3287             break;
3288         case EOpAtomicXor:
3289         case EOpImageAtomicXor:
3290             writeAtomicOp = spirv::WriteAtomicXor;
3291             break;
3292         case EOpAtomicExchange:
3293         case EOpImageAtomicExchange:
3294             writeAtomicOp = spirv::WriteAtomicExchange;
3295             break;
3296         case EOpAtomicCompSwap:
3297         case EOpImageAtomicCompSwap:
3298             // Generate this special instruction right here and early out.  Note again that the
3299             // value and compare parameters of OpAtomicCompareExchange are in the opposite order
3300             // from GLSL.
3301             ASSERT(parameters.size() == 2);
3302             spirv::WriteAtomicCompareExchange(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3303                                               result, pointerId, scopeId, semanticsId, semanticsId,
3304                                               parameters[1], parameters[0]);
3305             return result;
3306         default:
3307             UNREACHABLE();
3308     }
3309 
3310     // Write the instruction.
3311     ASSERT(parameters.size() == 1);
3312     writeAtomicOp(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, pointerId, scopeId,
3313                   semanticsId, parameters[0]);
3314 
3315     return result;
3316 }
3317 
createImageTextureBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)3318 spirv::IdRef OutputSPIRVTraverser::createImageTextureBuiltIn(TIntermOperator *node,
3319                                                              spirv::IdRef resultTypeId)
3320 {
3321     const TOperator op                = node->getOp();
3322     const TFunction *function         = node->getAsAggregate()->getFunction();
3323     const TType &samplerType          = function->getParam(0)->getType();
3324     const TBasicType samplerBasicType = samplerType.getBasicType();
3325 
3326     // Load the parameters.
3327     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
3328 
3329     // GLSL texture* and image* built-ins map to the following SPIR-V instructions.  Some of these
3330     // instructions take a "sampled image" while the others take the image itself.  In these
3331     // functions, the image, coordinates and Dref (for shadow sampling) are specified as positional
3332     // parameters while the rest are bundled in a list of image operands.
3333     //
3334     // Image operations that query:
3335     //
3336     // - OpImageQuerySizeLod
3337     // - OpImageQuerySize
3338     // - OpImageQueryLod <-- sampled image
3339     // - OpImageQueryLevels
3340     // - OpImageQuerySamples
3341     //
3342     // Image operations that read/write:
3343     //
3344     // - OpImageSampleImplicitLod <-- sampled image
3345     // - OpImageSampleExplicitLod <-- sampled image
3346     // - OpImageSampleDrefImplicitLod <-- sampled image
3347     // - OpImageSampleDrefExplicitLod <-- sampled image
3348     // - OpImageSampleProjImplicitLod <-- sampled image
3349     // - OpImageSampleProjExplicitLod <-- sampled image
3350     // - OpImageSampleProjDrefImplicitLod <-- sampled image
3351     // - OpImageSampleProjDrefExplicitLod <-- sampled image
3352     // - OpImageFetch
3353     // - OpImageGather <-- sampled image
3354     // - OpImageDrefGather <-- sampled image
3355     // - OpImageRead
3356     // - OpImageWrite
3357     //
3358     // The additional image parameters are:
3359     //
3360     // - Bias: Only used with ImplicitLod.
3361     // - Lod: Only used with ExplicitLod.
3362     // - Grad: 2x operands; dx and dy.  Only used with ExplicitLod.
3363     // - ConstOffset: Constant offset added to coordinates of OpImage*Gather.
3364     // - Offset: Non-constant offset added to coordinates of OpImage*Gather.
3365     // - ConstOffsets: Constant offsets added to coordinates of OpImage*Gather.
3366     // - Sample: Only used with OpImageFetch, OpImageRead and OpImageWrite.
3367     //
3368     // Where GLSL's built-in takes a sampler but SPIR-V expects an image, OpImage can be used to get
3369     // the SPIR-V image out of a SPIR-V sampled image.
3370 
3371     // The first parameter, which is either a sampled image or an image.  Some GLSL built-ins
3372     // receive a sampled image but their SPIR-V equivalent expects an image.  OpImage is used in
3373     // that case.
3374     spirv::IdRef image                = parameters[0];
3375     bool extractImageFromSampledImage = false;
3376 
3377     // The argument index for different possible parameters.  0 indicates that the argument is
3378     // unused.  Coordinates are usually at index 1, so it's pre-initialized.
3379     size_t coordinatesIndex     = 1;
3380     size_t biasIndex            = 0;
3381     size_t lodIndex             = 0;
3382     size_t compareIndex         = 0;
3383     size_t dPdxIndex            = 0;
3384     size_t dPdyIndex            = 0;
3385     size_t offsetIndex          = 0;
3386     size_t offsetsIndex         = 0;
3387     size_t gatherComponentIndex = 0;
3388     size_t sampleIndex          = 0;
3389     size_t dataIndex            = 0;
3390 
3391     // Whether this is a Dref variant of a sample call.
3392     bool isDref = IsShadowSampler(samplerBasicType);
3393     // Whether this is a Proj variant of a sample call.
3394     bool isProj = false;
3395 
3396     // The SPIR-V op used to implement the built-in.  For OpImageSample* instructions,
3397     // OpImageSampleImplicitLod is initially specified, which is later corrected based on |isDref|
3398     // and |isProj|.
3399     spv::Op spirvOp = BuiltInGroup::IsTexture(op) ? spv::OpImageSampleImplicitLod : spv::OpNop;
3400 
3401     // Organize the parameters and decide the SPIR-V Op to use.
3402     switch (op)
3403     {
3404         case EOpTexture2D:
3405         case EOpTextureCube:
3406         case EOpTexture3D:
3407         case EOpShadow2DEXT:
3408         case EOpTexture2DRect:
3409         case EOpTextureVideoWEBGL:
3410         case EOpTexture:
3411 
3412         case EOpTexture2DBias:
3413         case EOpTextureCubeBias:
3414         case EOpTexture3DBias:
3415         case EOpTextureBias:
3416 
3417             // For shadow cube arrays, the compare value is specified through an additional
3418             // parameter, while for the rest is taken out of the coordinates.
3419             if (function->getParamCount() == 3)
3420             {
3421                 if (samplerBasicType == EbtSamplerCubeArrayShadow)
3422                 {
3423                     compareIndex = 2;
3424                 }
3425                 else
3426                 {
3427                     biasIndex = 2;
3428                 }
3429             }
3430             else if (function->getParamCount() == 4 &&
3431                      samplerBasicType == EbtSamplerCubeArrayShadow)
3432             {
3433                 compareIndex = 2;
3434                 biasIndex    = 3;
3435             }
3436             break;
3437 
3438         case EOpTexture2DProj:
3439         case EOpTexture3DProj:
3440         case EOpShadow2DProjEXT:
3441         case EOpTexture2DRectProj:
3442         case EOpTextureProj:
3443 
3444         case EOpTexture2DProjBias:
3445         case EOpTexture3DProjBias:
3446         case EOpTextureProjBias:
3447 
3448             isProj = true;
3449             if (function->getParamCount() == 3)
3450             {
3451                 biasIndex = 2;
3452             }
3453             break;
3454 
3455         case EOpTexture3DLod:
3456 
3457         case EOpTexture2DLodVS:
3458         case EOpTextureCubeLodVS:
3459 
3460         case EOpTexture2DLodEXTFS:
3461         case EOpTextureCubeLodEXTFS:
3462         case EOpTextureLod:
3463 
3464             if (samplerBasicType == EbtSamplerCubeArrayShadow)
3465             {
3466                 ASSERT(function->getParamCount() == 4);
3467                 compareIndex = 2;
3468                 lodIndex     = 3;
3469             }
3470             else
3471             {
3472                 ASSERT(function->getParamCount() == 3);
3473                 lodIndex = 2;
3474             }
3475             break;
3476 
3477         case EOpTexture3DProjLod:
3478 
3479         case EOpTexture2DProjLodVS:
3480 
3481         case EOpTexture2DProjLodEXTFS:
3482         case EOpTextureProjLod:
3483 
3484             ASSERT(function->getParamCount() == 3);
3485             isProj   = true;
3486             lodIndex = 2;
3487             break;
3488 
3489         case EOpTexelFetch:
3490         case EOpTexelFetchOffset:
3491             // texelFetch has the following forms:
3492             //
3493             // - texelFetch(sampler, P);
3494             // - texelFetch(sampler, P, lod);
3495             // - texelFetch(samplerMS, P, sample);
3496             //
3497             // texelFetchOffset has an additional offset parameter at the end.
3498             //
3499             // In SPIR-V, OpImageFetch is used which operates on the image itself.
3500             spirvOp                      = spv::OpImageFetch;
3501             extractImageFromSampledImage = true;
3502 
3503             if (IsSamplerMS(samplerBasicType))
3504             {
3505                 ASSERT(function->getParamCount() == 3);
3506                 sampleIndex = 2;
3507             }
3508             else if (function->getParamCount() >= 3)
3509             {
3510                 lodIndex = 2;
3511             }
3512             if (op == EOpTexelFetchOffset)
3513             {
3514                 offsetIndex = function->getParamCount() - 1;
3515             }
3516             break;
3517 
3518         case EOpTexture2DGradEXT:
3519         case EOpTextureCubeGradEXT:
3520         case EOpTextureGrad:
3521 
3522             ASSERT(function->getParamCount() == 4);
3523             dPdxIndex = 2;
3524             dPdyIndex = 3;
3525             break;
3526 
3527         case EOpTexture2DProjGradEXT:
3528         case EOpTextureProjGrad:
3529 
3530             ASSERT(function->getParamCount() == 4);
3531             isProj    = true;
3532             dPdxIndex = 2;
3533             dPdyIndex = 3;
3534             break;
3535 
3536         case EOpTextureOffset:
3537         case EOpTextureOffsetBias:
3538 
3539             ASSERT(function->getParamCount() >= 3);
3540             offsetIndex = 2;
3541             if (function->getParamCount() == 4)
3542             {
3543                 biasIndex = 3;
3544             }
3545             break;
3546 
3547         case EOpTextureProjOffset:
3548         case EOpTextureProjOffsetBias:
3549 
3550             ASSERT(function->getParamCount() >= 3);
3551             isProj      = true;
3552             offsetIndex = 2;
3553             if (function->getParamCount() == 4)
3554             {
3555                 biasIndex = 3;
3556             }
3557             break;
3558 
3559         case EOpTextureLodOffset:
3560 
3561             ASSERT(function->getParamCount() == 4);
3562             lodIndex    = 2;
3563             offsetIndex = 3;
3564             break;
3565 
3566         case EOpTextureProjLodOffset:
3567 
3568             ASSERT(function->getParamCount() == 4);
3569             isProj      = true;
3570             lodIndex    = 2;
3571             offsetIndex = 3;
3572             break;
3573 
3574         case EOpTextureGradOffset:
3575 
3576             ASSERT(function->getParamCount() == 5);
3577             dPdxIndex   = 2;
3578             dPdyIndex   = 3;
3579             offsetIndex = 4;
3580             break;
3581 
3582         case EOpTextureProjGradOffset:
3583 
3584             ASSERT(function->getParamCount() == 5);
3585             isProj      = true;
3586             dPdxIndex   = 2;
3587             dPdyIndex   = 3;
3588             offsetIndex = 4;
3589             break;
3590 
3591         case EOpTextureGather:
3592 
3593             // For shadow textures, refZ (same as Dref) is specified as the last argument.
3594             // Otherwise a component may be specified which defaults to 0 if not specified.
3595             spirvOp = spv::OpImageGather;
3596             if (isDref)
3597             {
3598                 ASSERT(function->getParamCount() == 3);
3599                 compareIndex = 2;
3600             }
3601             else if (function->getParamCount() == 3)
3602             {
3603                 gatherComponentIndex = 2;
3604             }
3605             break;
3606 
3607         case EOpTextureGatherOffset:
3608         case EOpTextureGatherOffsetComp:
3609 
3610         case EOpTextureGatherOffsets:
3611         case EOpTextureGatherOffsetsComp:
3612 
3613             // textureGatherOffset and textureGatherOffsets have the following forms:
3614             //
            // - textureGatherOffset*(sampler, P, offset*);
            // - textureGatherOffset*(sampler, P, offset*, component);
            // - textureGatherOffset*(sampler, P, refZ, offset*);
3618             //
3619             spirvOp = spv::OpImageGather;
3620             if (isDref)
3621             {
3622                 ASSERT(function->getParamCount() == 4);
3623                 compareIndex = 2;
3624             }
3625             else if (function->getParamCount() == 4)
3626             {
3627                 gatherComponentIndex = 3;
3628             }
3629 
3630             ASSERT(function->getParamCount() >= 3);
3631             if (BuiltInGroup::IsTextureGatherOffset(op))
3632             {
3633                 offsetIndex = isDref ? 3 : 2;
3634             }
3635             else
3636             {
3637                 offsetsIndex = isDref ? 3 : 2;
3638             }
3639             break;
3640 
3641         case EOpImageStore:
3642             // imageStore has the following forms:
3643             //
3644             // - imageStore(image, P, data);
3645             // - imageStore(imageMS, P, sample, data);
3646             //
3647             spirvOp = spv::OpImageWrite;
3648             if (IsSamplerMS(samplerBasicType))
3649             {
3650                 ASSERT(function->getParamCount() == 4);
3651                 sampleIndex = 2;
3652                 dataIndex   = 3;
3653             }
3654             else
3655             {
3656                 ASSERT(function->getParamCount() == 3);
3657                 dataIndex = 2;
3658             }
3659             break;
3660 
3661         case EOpImageLoad:
            // imageLoad has the following forms:
3663             //
3664             // - imageLoad(image, P);
3665             // - imageLoad(imageMS, P, sample);
3666             //
3667             spirvOp = spv::OpImageRead;
3668             if (IsSamplerMS(samplerBasicType))
3669             {
3670                 ASSERT(function->getParamCount() == 3);
3671                 sampleIndex = 2;
3672             }
3673             else
3674             {
3675                 ASSERT(function->getParamCount() == 2);
3676             }
3677             break;
3678 
3679             // Queries:
3680         case EOpTextureSize:
3681         case EOpImageSize:
3682             // textureSize has the following forms:
3683             //
3684             // - textureSize(sampler);
3685             // - textureSize(sampler, lod);
3686             //
3687             // while imageSize has only one form:
3688             //
3689             // - imageSize(image);
3690             //
3691             extractImageFromSampledImage = true;
3692             if (function->getParamCount() == 2)
3693             {
3694                 spirvOp  = spv::OpImageQuerySizeLod;
3695                 lodIndex = 1;
3696             }
3697             else
3698             {
3699                 spirvOp = spv::OpImageQuerySize;
3700             }
3701             // No coordinates parameter.
3702             coordinatesIndex = 0;
3703             // No dref parameter.
3704             isDref = false;
3705             break;
3706 
3707         default:
3708             UNREACHABLE();
3709     }
3710 
3711     // If an implicit-lod instruction is used outside a fragment shader, change that to an explicit
3712     // one as they are not allowed in SPIR-V outside fragment shaders.
3713     const bool noLodSupport = IsSamplerBuffer(samplerBasicType) ||
3714                               IsImageBuffer(samplerBasicType) || IsSamplerMS(samplerBasicType) ||
3715                               IsImageMS(samplerBasicType);
3716     const bool makeLodExplicit =
3717         mCompiler->getShaderType() != GL_FRAGMENT_SHADER && lodIndex == 0 && dPdxIndex == 0 &&
3718         !noLodSupport && (spirvOp == spv::OpImageSampleImplicitLod || spirvOp == spv::OpImageFetch);
3719 
3720     // Apply any necessary fix up.
3721 
3722     if (extractImageFromSampledImage && IsSampler(samplerBasicType))
3723     {
3724         // Get the (non-sampled) image type.
3725         SpirvType imageType = mBuilder.getSpirvType(samplerType, {});
3726         ASSERT(!imageType.isSamplerBaseImage);
3727         imageType.isSamplerBaseImage            = true;
3728         const spirv::IdRef extractedImageTypeId = mBuilder.getSpirvTypeData(imageType, nullptr).id;
3729 
3730         // Use OpImage to get the image out of the sampled image.
3731         const spirv::IdRef extractedImage = mBuilder.getNewId({});
3732         spirv::WriteImage(mBuilder.getSpirvCurrentFunctionBlock(), extractedImageTypeId,
3733                           extractedImage, image);
3734         image = extractedImage;
3735     }
3736 
3737     // Gather operands as necessary.
3738 
3739     // - Coordinates
3740     uint8_t coordinatesChannelCount = 0;
3741     spirv::IdRef coordinatesId;
3742     const TType *coordinatesType = nullptr;
3743     if (coordinatesIndex > 0)
3744     {
3745         coordinatesId           = parameters[coordinatesIndex];
3746         coordinatesType         = &node->getChildNode(coordinatesIndex)->getAsTyped()->getType();
3747         coordinatesChannelCount = coordinatesType->getNominalSize();
3748     }
3749 
3750     // - Dref; either specified as a compare/refz argument (cube array, gather), or:
3751     //   * coordinates.z for proj variants
3752     //   * coordinates.<last> for others
3753     spirv::IdRef drefId;
3754     if (compareIndex > 0)
3755     {
3756         drefId = parameters[compareIndex];
3757     }
3758     else if (isDref)
3759     {
3760         // Get the component index
3761         ASSERT(coordinatesChannelCount > 0);
3762         uint8_t drefComponent = isProj ? 2 : coordinatesChannelCount - 1;
3763 
3764         // Get the component type
3765         SpirvType drefSpirvType       = mBuilder.getSpirvType(*coordinatesType, {});
3766         drefSpirvType.primarySize     = 1;
3767         const spirv::IdRef drefTypeId = mBuilder.getSpirvTypeData(drefSpirvType, nullptr).id;
3768 
3769         // Extract the dref component out of coordinates.
3770         drefId = mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3771         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), drefTypeId, drefId,
3772                                      coordinatesId, {spirv::LiteralInteger(drefComponent)});
3773     }
3774 
3775     // - Gather component
3776     spirv::IdRef gatherComponentId;
3777     if (gatherComponentIndex > 0)
3778     {
3779         gatherComponentId = parameters[gatherComponentIndex];
3780     }
3781     else if (spirvOp == spv::OpImageGather)
3782     {
3783         // If comp is not specified, component 0 is taken as default.
3784         gatherComponentId = mBuilder.getIntConstant(0);
3785     }
3786 
3787     // - Image write data
3788     spirv::IdRef dataId;
3789     if (dataIndex > 0)
3790     {
3791         dataId = parameters[dataIndex];
3792     }
3793 
3794     // - Other operands
3795     spv::ImageOperandsMask operandsMask = spv::ImageOperandsMaskNone;
3796     spirv::IdRefList imageOperandsList;
3797 
3798     if (biasIndex > 0)
3799     {
3800         operandsMask = operandsMask | spv::ImageOperandsBiasMask;
3801         imageOperandsList.push_back(parameters[biasIndex]);
3802     }
3803     if (lodIndex > 0)
3804     {
3805         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3806         imageOperandsList.push_back(parameters[lodIndex]);
3807     }
3808     else if (makeLodExplicit)
3809     {
3810         // If the implicit-lod variant is used outside fragment shaders, switch to explicit and use
3811         // lod 0.
3812         operandsMask = operandsMask | spv::ImageOperandsLodMask;
3813         imageOperandsList.push_back(spirvOp == spv::OpImageFetch ? mBuilder.getUintConstant(0)
3814                                                                  : mBuilder.getFloatConstant(0));
3815     }
3816     if (dPdxIndex > 0)
3817     {
3818         ASSERT(dPdyIndex > 0);
3819         operandsMask = operandsMask | spv::ImageOperandsGradMask;
3820         imageOperandsList.push_back(parameters[dPdxIndex]);
3821         imageOperandsList.push_back(parameters[dPdyIndex]);
3822     }
3823     if (offsetIndex > 0)
3824     {
3825         // Non-const offsets require the ImageGatherExtended feature.
3826         if (node->getChildNode(offsetIndex)->getAsTyped()->hasConstantValue())
3827         {
3828             operandsMask = operandsMask | spv::ImageOperandsConstOffsetMask;
3829         }
3830         else
3831         {
3832             ASSERT(spirvOp == spv::OpImageGather);
3833 
3834             operandsMask = operandsMask | spv::ImageOperandsOffsetMask;
3835             mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3836         }
3837         imageOperandsList.push_back(parameters[offsetIndex]);
3838     }
3839     if (offsetsIndex > 0)
3840     {
3841         ASSERT(node->getChildNode(offsetsIndex)->getAsTyped()->hasConstantValue());
3842 
3843         operandsMask = operandsMask | spv::ImageOperandsConstOffsetsMask;
3844         mBuilder.addCapability(spv::CapabilityImageGatherExtended);
3845         imageOperandsList.push_back(parameters[offsetsIndex]);
3846     }
3847     if (sampleIndex > 0)
3848     {
3849         operandsMask = operandsMask | spv::ImageOperandsSampleMask;
3850         imageOperandsList.push_back(parameters[sampleIndex]);
3851     }
3852 
3853     const spv::ImageOperandsMask *imageOperands =
3854         imageOperandsList.empty() ? nullptr : &operandsMask;
3855 
3856     // GLSL and SPIR-V are different in the way the projective component is specified:
3857     //
3858     // In GLSL:
3859     //
3860     // > The texture coordinates consumed from P, not including the last component of P, are divided
3861     // > by the last component of P.
3862     //
3863     // In SPIR-V, there's a similar language (division by last element), but with the following
3864     // added:
3865     //
3866     // > ... all unused components will appear after all used components.
3867     //
3868     // So for example for textureProj(sampler, vec4 P), the projective coordinates are P.xy/P.w,
3869     // where P.z is ignored.  In SPIR-V instead that would be P.xy/P.z and P.w is ignored.
3870     //
3871     if (isProj)
3872     {
3873         uint8_t requiredChannelCount = coordinatesChannelCount;
3874         // texture*Proj* operate on the following parameters:
3875         //
3876         // - sampler2D, vec3 P
3877         // - sampler2D, vec4 P
3878         // - sampler2DRect, vec3 P
3879         // - sampler2DRect, vec4 P
3880         // - sampler3D, vec4 P
3881         // - sampler2DShadow, vec4 P
3882         // - sampler2DRectShadow, vec4 P
3883         //
3884         // Of these cases, only (sampler2D*, vec4 P) requires moving the proj channel from .w to the
3885         // appropriate location (.y for 1D and .z for 2D).
3886         if (IsSampler2D(samplerBasicType))
3887         {
3888             requiredChannelCount = 3;
3889         }
3890         if (requiredChannelCount != coordinatesChannelCount)
3891         {
3892             ASSERT(coordinatesChannelCount == 4);
3893 
3894             // Get the component type
3895             SpirvType spirvType                  = mBuilder.getSpirvType(*coordinatesType, {});
3896             const spirv::IdRef coordinatesTypeId = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3897             spirvType.primarySize                = 1;
3898             const spirv::IdRef channelTypeId     = mBuilder.getSpirvTypeData(spirvType, nullptr).id;
3899 
3900             // Extract the last component out of coordinates.
3901             const spirv::IdRef projChannelId =
3902                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3903             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), channelTypeId,
3904                                          projChannelId, coordinatesId,
3905                                          {spirv::LiteralInteger(coordinatesChannelCount - 1)});
3906 
3907             // Insert it after the channels that are consumed.  The extra channels are ignored per
3908             // the SPIR-V spec.
3909             const spirv::IdRef newCoordinatesId =
3910                 mBuilder.getNewId(mBuilder.getDecorations(*coordinatesType));
3911             spirv::WriteCompositeInsert(mBuilder.getSpirvCurrentFunctionBlock(), coordinatesTypeId,
3912                                         newCoordinatesId, projChannelId, coordinatesId,
3913                                         {spirv::LiteralInteger(requiredChannelCount - 1)});
3914             coordinatesId = newCoordinatesId;
3915         }
3916     }
3917 
3918     // Select the correct sample Op based on whether the Proj, Dref or Explicit variants are used.
3919     if (spirvOp == spv::OpImageSampleImplicitLod)
3920     {
3921         ASSERT(!noLodSupport);
3922         const bool isExplicitLod = lodIndex != 0 || makeLodExplicit || dPdxIndex != 0;
3923         if (isDref)
3924         {
3925             if (isProj)
3926             {
3927                 spirvOp = isExplicitLod ? spv::OpImageSampleProjDrefExplicitLod
3928                                         : spv::OpImageSampleProjDrefImplicitLod;
3929             }
3930             else
3931             {
3932                 spirvOp = isExplicitLod ? spv::OpImageSampleDrefExplicitLod
3933                                         : spv::OpImageSampleDrefImplicitLod;
3934             }
3935         }
3936         else
3937         {
3938             if (isProj)
3939             {
3940                 spirvOp = isExplicitLod ? spv::OpImageSampleProjExplicitLod
3941                                         : spv::OpImageSampleProjImplicitLod;
3942             }
3943             else
3944             {
3945                 spirvOp =
3946                     isExplicitLod ? spv::OpImageSampleExplicitLod : spv::OpImageSampleImplicitLod;
3947             }
3948         }
3949     }
3950     if (spirvOp == spv::OpImageGather && isDref)
3951     {
3952         spirvOp = spv::OpImageDrefGather;
3953     }
3954 
3955     spirv::IdRef result;
3956     if (spirvOp != spv::OpImageWrite)
3957     {
3958         result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
3959     }
3960 
3961     switch (spirvOp)
3962     {
3963         case spv::OpImageQuerySizeLod:
3964             mBuilder.addCapability(spv::CapabilityImageQuery);
3965             ASSERT(imageOperandsList.size() == 1);
3966             spirv::WriteImageQuerySizeLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3967                                           result, image, imageOperandsList[0]);
3968             break;
3969         case spv::OpImageQuerySize:
3970             mBuilder.addCapability(spv::CapabilityImageQuery);
3971             spirv::WriteImageQuerySize(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3972                                        result, image);
3973             break;
3974         case spv::OpImageQueryLod:
3975             mBuilder.addCapability(spv::CapabilityImageQuery);
3976             spirv::WriteImageQueryLod(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
3977                                       image, coordinatesId);
3978             break;
3979         case spv::OpImageQueryLevels:
3980             mBuilder.addCapability(spv::CapabilityImageQuery);
3981             spirv::WriteImageQueryLevels(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3982                                          result, image);
3983             break;
3984         case spv::OpImageQuerySamples:
3985             mBuilder.addCapability(spv::CapabilityImageQuery);
3986             spirv::WriteImageQuerySamples(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
3987                                           result, image);
3988             break;
3989         case spv::OpImageSampleImplicitLod:
3990             spirv::WriteImageSampleImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3991                                                resultTypeId, result, image, coordinatesId,
3992                                                imageOperands, imageOperandsList);
3993             break;
3994         case spv::OpImageSampleExplicitLod:
3995             spirv::WriteImageSampleExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
3996                                                resultTypeId, result, image, coordinatesId,
3997                                                *imageOperands, imageOperandsList);
3998             break;
3999         case spv::OpImageSampleDrefImplicitLod:
4000             spirv::WriteImageSampleDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4001                                                    resultTypeId, result, image, coordinatesId,
4002                                                    drefId, imageOperands, imageOperandsList);
4003             break;
4004         case spv::OpImageSampleDrefExplicitLod:
4005             spirv::WriteImageSampleDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4006                                                    resultTypeId, result, image, coordinatesId,
4007                                                    drefId, *imageOperands, imageOperandsList);
4008             break;
4009         case spv::OpImageSampleProjImplicitLod:
4010             spirv::WriteImageSampleProjImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4011                                                    resultTypeId, result, image, coordinatesId,
4012                                                    imageOperands, imageOperandsList);
4013             break;
4014         case spv::OpImageSampleProjExplicitLod:
4015             spirv::WriteImageSampleProjExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4016                                                    resultTypeId, result, image, coordinatesId,
4017                                                    *imageOperands, imageOperandsList);
4018             break;
4019         case spv::OpImageSampleProjDrefImplicitLod:
4020             spirv::WriteImageSampleProjDrefImplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4021                                                        resultTypeId, result, image, coordinatesId,
4022                                                        drefId, imageOperands, imageOperandsList);
4023             break;
4024         case spv::OpImageSampleProjDrefExplicitLod:
4025             spirv::WriteImageSampleProjDrefExplicitLod(mBuilder.getSpirvCurrentFunctionBlock(),
4026                                                        resultTypeId, result, image, coordinatesId,
4027                                                        drefId, *imageOperands, imageOperandsList);
4028             break;
4029         case spv::OpImageFetch:
4030             spirv::WriteImageFetch(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4031                                    image, coordinatesId, imageOperands, imageOperandsList);
4032             break;
4033         case spv::OpImageGather:
4034             spirv::WriteImageGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4035                                     image, coordinatesId, gatherComponentId, imageOperands,
4036                                     imageOperandsList);
4037             break;
4038         case spv::OpImageDrefGather:
4039             spirv::WriteImageDrefGather(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId,
4040                                         result, image, coordinatesId, drefId, imageOperands,
4041                                         imageOperandsList);
4042             break;
4043         case spv::OpImageRead:
4044             spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4045                                   image, coordinatesId, imageOperands, imageOperandsList);
4046             break;
4047         case spv::OpImageWrite:
4048             spirv::WriteImageWrite(mBuilder.getSpirvCurrentFunctionBlock(), image, coordinatesId,
4049                                    dataId, imageOperands, imageOperandsList);
4050             break;
4051         default:
4052             UNREACHABLE();
4053     }
4054 
4055     return result;
4056 }
4057 
createSubpassLoadBuiltIn(TIntermOperator * node,spirv::IdRef resultTypeId)4058 spirv::IdRef OutputSPIRVTraverser::createSubpassLoadBuiltIn(TIntermOperator *node,
4059                                                             spirv::IdRef resultTypeId)
4060 {
4061     // Load the parameters.
4062     spirv::IdRefList parameters = loadAllParams(node, 0, nullptr);
4063     const spirv::IdRef image    = parameters[0];
4064 
4065     // If multisampled, an additional parameter specifies the sample.  This is passed through as an
4066     // extra image operand.
4067     const bool hasSampleParam = parameters.size() == 2;
4068     const spv::ImageOperandsMask operandsMask =
4069         hasSampleParam ? spv::ImageOperandsSampleMask : spv::ImageOperandsMaskNone;
4070     spirv::IdRefList imageOperandsList;
4071     if (hasSampleParam)
4072     {
4073         imageOperandsList.push_back(parameters[1]);
4074     }
4075 
4076     // |subpassLoad| is implemented with OpImageRead.  This OP takes a coordinate, which is unused
4077     // and is set to (0, 0) here.
4078     const spirv::IdRef coordId = mBuilder.getNullConstant(mBuilder.getBasicTypeId(EbtUInt, 2));
4079 
4080     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4081     spirv::WriteImageRead(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result, image,
4082                           coordId, hasSampleParam ? &operandsMask : nullptr, imageOperandsList);
4083 
4084     return result;
4085 }
4086 
createInterpolate(TIntermOperator * node,spirv::IdRef resultTypeId)4087 spirv::IdRef OutputSPIRVTraverser::createInterpolate(TIntermOperator *node,
4088                                                      spirv::IdRef resultTypeId)
4089 {
4090     spv::GLSLstd450 extendedInst = spv::GLSLstd450Bad;
4091 
4092     mBuilder.addCapability(spv::CapabilityInterpolationFunction);
4093 
4094     switch (node->getOp())
4095     {
4096         case EOpInterpolateAtCentroid:
4097             extendedInst = spv::GLSLstd450InterpolateAtCentroid;
4098             break;
4099         case EOpInterpolateAtSample:
4100             extendedInst = spv::GLSLstd450InterpolateAtSample;
4101             break;
4102         case EOpInterpolateAtOffset:
4103             extendedInst = spv::GLSLstd450InterpolateAtOffset;
4104             break;
4105         default:
4106             UNREACHABLE();
4107     }
4108 
4109     size_t childCount = node->getChildCount();
4110 
4111     spirv::IdRefList parameters;
4112 
4113     // interpolateAt* takes the interpolant as the first argument, *pointer* to which needs to be
4114     // passed to the instruction.  Except interpolateAtCentroid, another parameter follows.
4115     parameters.push_back(accessChainCollapse(&mNodeData[mNodeData.size() - childCount]));
4116     if (childCount > 1)
4117     {
4118         parameters.push_back(accessChainLoad(
4119             &mNodeData.back(), node->getChildNode(1)->getAsTyped()->getType(), nullptr));
4120     }
4121 
4122     const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
4123 
4124     spirv::WriteExtInst(mBuilder.getSpirvCurrentFunctionBlock(), resultTypeId, result,
4125                         mBuilder.getExtInstImportIdStd(),
4126                         spirv::LiteralExtInstInteger(extendedInst), parameters);
4127 
4128     return result;
4129 }
4130 
castBasicType(spirv::IdRef value,const TType & valueType,const TType & expectedType,spirv::IdRef * resultTypeIdOut)4131 spirv::IdRef OutputSPIRVTraverser::castBasicType(spirv::IdRef value,
4132                                                  const TType &valueType,
4133                                                  const TType &expectedType,
4134                                                  spirv::IdRef *resultTypeIdOut)
4135 {
4136     const TBasicType expectedBasicType = expectedType.getBasicType();
4137     if (valueType.getBasicType() == expectedBasicType)
4138     {
4139         return value;
4140     }
4141 
4142     // Make sure no attempt is made to cast a matrix to int/uint.
4143     ASSERT(!valueType.isMatrix() || expectedBasicType == EbtFloat);
4144 
4145     SpirvType valueSpirvType                            = mBuilder.getSpirvType(valueType, {});
4146     valueSpirvType.type                                 = expectedBasicType;
4147     valueSpirvType.typeSpec.isOrHasBoolInInterfaceBlock = false;
4148     const spirv::IdRef castTypeId = mBuilder.getSpirvTypeData(valueSpirvType, nullptr).id;
4149 
4150     const spirv::IdRef castValue = mBuilder.getNewId(mBuilder.getDecorations(expectedType));
4151 
4152     // Write the instruction that casts between types.  Different instructions are used based on the
4153     // types being converted.
4154     //
4155     // - int/uint <-> float: OpConvert*To*
4156     // - int <-> uint: OpBitcast
4157     // - bool --> int/uint/float: OpSelect with 0 and 1
4158     // - int/uint --> bool: OPINotEqual 0
4159     // - float --> bool: OpFUnordNotEqual 0
4160 
4161     WriteUnaryOp writeUnaryOp     = nullptr;
4162     WriteBinaryOp writeBinaryOp   = nullptr;
4163     WriteTernaryOp writeTernaryOp = nullptr;
4164 
4165     spirv::IdRef zero;
4166     spirv::IdRef one;
4167 
4168     switch (valueType.getBasicType())
4169     {
4170         case EbtFloat:
4171             switch (expectedBasicType)
4172             {
4173                 case EbtInt:
4174                     writeUnaryOp = spirv::WriteConvertFToS;
4175                     break;
4176                 case EbtUInt:
4177                     writeUnaryOp = spirv::WriteConvertFToU;
4178                     break;
4179                 case EbtBool:
4180                     zero          = mBuilder.getVecConstant(0, valueType.getNominalSize());
4181                     writeBinaryOp = spirv::WriteFUnordNotEqual;
4182                     break;
4183                 default:
4184                     UNREACHABLE();
4185             }
4186             break;
4187 
4188         case EbtInt:
4189         case EbtUInt:
4190             switch (expectedBasicType)
4191             {
4192                 case EbtFloat:
4193                     writeUnaryOp = valueType.getBasicType() == EbtInt ? spirv::WriteConvertSToF
4194                                                                       : spirv::WriteConvertUToF;
4195                     break;
4196                 case EbtInt:
4197                 case EbtUInt:
4198                     writeUnaryOp = spirv::WriteBitcast;
4199                     break;
4200                 case EbtBool:
4201                     zero          = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4202                     writeBinaryOp = spirv::WriteINotEqual;
4203                     break;
4204                 default:
4205                     UNREACHABLE();
4206             }
4207             break;
4208 
4209         case EbtBool:
4210             writeTernaryOp = spirv::WriteSelect;
4211             switch (expectedBasicType)
4212             {
4213                 case EbtFloat:
4214                     zero = mBuilder.getVecConstant(0, valueType.getNominalSize());
4215                     one  = mBuilder.getVecConstant(1, valueType.getNominalSize());
4216                     break;
4217                 case EbtInt:
4218                     zero = mBuilder.getIvecConstant(0, valueType.getNominalSize());
4219                     one  = mBuilder.getIvecConstant(1, valueType.getNominalSize());
4220                     break;
4221                 case EbtUInt:
4222                     zero = mBuilder.getUvecConstant(0, valueType.getNominalSize());
4223                     one  = mBuilder.getUvecConstant(1, valueType.getNominalSize());
4224                     break;
4225                 default:
4226                     UNREACHABLE();
4227             }
4228             break;
4229 
4230         default:
4231             UNREACHABLE();
4232     }
4233 
4234     if (writeUnaryOp)
4235     {
4236         writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value);
4237     }
4238     else if (writeBinaryOp)
4239     {
4240         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, zero);
4241     }
4242     else
4243     {
4244         ASSERT(writeTernaryOp);
4245         writeTernaryOp(mBuilder.getSpirvCurrentFunctionBlock(), castTypeId, castValue, value, one,
4246                        zero);
4247     }
4248 
4249     if (resultTypeIdOut)
4250     {
4251         *resultTypeIdOut = castTypeId;
4252     }
4253 
4254     return castValue;
4255 }
4256 
cast(spirv::IdRef value,const TType & valueType,const SpirvTypeSpec & valueTypeSpec,const SpirvTypeSpec & expectedTypeSpec,spirv::IdRef * resultTypeIdOut)4257 spirv::IdRef OutputSPIRVTraverser::cast(spirv::IdRef value,
4258                                         const TType &valueType,
4259                                         const SpirvTypeSpec &valueTypeSpec,
4260                                         const SpirvTypeSpec &expectedTypeSpec,
4261                                         spirv::IdRef *resultTypeIdOut)
4262 {
4263     // If there's no difference in type specialization, there's nothing to cast.
4264     if (valueTypeSpec.blockStorage == expectedTypeSpec.blockStorage &&
4265         valueTypeSpec.isInvariantBlock == expectedTypeSpec.isInvariantBlock &&
4266         valueTypeSpec.isRowMajorQualifiedBlock == expectedTypeSpec.isRowMajorQualifiedBlock &&
4267         valueTypeSpec.isRowMajorQualifiedArray == expectedTypeSpec.isRowMajorQualifiedArray &&
4268         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock &&
4269         valueTypeSpec.isPatchIOBlock == expectedTypeSpec.isPatchIOBlock)
4270     {
4271         return value;
4272     }
4273 
4274     // At this point, a value is loaded with the |valueType| GLSL type which is of a SPIR-V type
4275     // specialized by |valueTypeSpec|.  However, it's being assigned (for example through operator=,
4276     // used in a constructor or passed as a function argument) where the same GLSL type is expected
4277     // but with different SPIR-V type specialization (|expectedTypeSpec|).
4278     //
4279     // If SPIR-V 1.4 is available, use OpCopyLogical if possible.  OpCopyLogical works on arrays and
4280     // structs, and only if the types are logically the same.  This means that arrays and structs
4281     // can be copied with this instruction despite their SpirvTypeSpec being different.  The only
4282     // exception is if there is a mismatch in the isOrHasBoolInInterfaceBlock type specialization
4283     // as it actually changes the type of the struct members.
4284     if (mCompileOptions.emitSPIRV14 && (valueType.isArray() || valueType.getStruct() != nullptr) &&
4285         valueTypeSpec.isOrHasBoolInInterfaceBlock == expectedTypeSpec.isOrHasBoolInInterfaceBlock)
4286     {
4287         const spirv::IdRef expectedTypeId =
4288             mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4289         const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4290 
4291         spirv::WriteCopyLogical(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId, expectedId,
4292                                 value);
4293         if (resultTypeIdOut)
4294         {
4295             *resultTypeIdOut = expectedTypeId;
4296         }
4297         return expectedId;
4298     }
4299 
4300     // The following code recursively copies the array elements or struct fields and then constructs
4301     // the final result with the expected SPIR-V type.
4302 
4303     // Interface blocks cannot be copied or passed as parameters in GLSL.
4304     ASSERT(!valueType.isInterfaceBlock());
4305 
4306     spirv::IdRefList constituents;
4307 
4308     if (valueType.isArray())
4309     {
4310         // Find the SPIR-V type specialization for the element type.
4311         SpirvTypeSpec valueElementTypeSpec    = valueTypeSpec;
4312         SpirvTypeSpec expectedElementTypeSpec = expectedTypeSpec;
4313 
4314         const bool isElementBlock = valueType.getStruct() != nullptr;
4315         const bool isElementArray = valueType.isArrayOfArrays();
4316 
4317         valueElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4318         expectedElementTypeSpec.onArrayElementSelection(isElementBlock, isElementArray);
4319 
4320         // Get the element type id.
4321         TType elementType(valueType);
4322         elementType.toArrayElementType();
4323 
4324         const spirv::IdRef elementTypeId =
4325             mBuilder.getTypeDataOverrideTypeSpec(elementType, valueElementTypeSpec).id;
4326 
4327         const SpirvDecorations elementDecorations = mBuilder.getDecorations(elementType);
4328 
4329         // Extract each element of the array and cast it to the expected type.
4330         for (unsigned int elementIndex = 0; elementIndex < valueType.getOutermostArraySize();
4331              ++elementIndex)
4332         {
4333             const spirv::IdRef elementId = mBuilder.getNewId(elementDecorations);
4334             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), elementTypeId,
4335                                          elementId, value, {spirv::LiteralInteger(elementIndex)});
4336 
4337             constituents.push_back(cast(elementId, elementType, valueElementTypeSpec,
4338                                         expectedElementTypeSpec, nullptr));
4339         }
4340     }
4341     else if (valueType.getStruct() != nullptr)
4342     {
4343         uint32_t fieldIndex = 0;
4344 
4345         // Extract each field of the struct and cast it to the expected type.
4346         for (const TField *field : valueType.getStruct()->fields())
4347         {
4348             const TType &fieldType = *field->type();
4349 
4350             // Find the SPIR-V type specialization for the field type.
4351             SpirvTypeSpec valueFieldTypeSpec    = valueTypeSpec;
4352             SpirvTypeSpec expectedFieldTypeSpec = expectedTypeSpec;
4353 
4354             valueFieldTypeSpec.onBlockFieldSelection(fieldType);
4355             expectedFieldTypeSpec.onBlockFieldSelection(fieldType);
4356 
4357             // Get the field type id.
4358             const spirv::IdRef fieldTypeId =
4359                 mBuilder.getTypeDataOverrideTypeSpec(fieldType, valueFieldTypeSpec).id;
4360 
4361             // Extract the field.
4362             const spirv::IdRef fieldId = mBuilder.getNewId(mBuilder.getDecorations(fieldType));
4363             spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId,
4364                                          fieldId, value, {spirv::LiteralInteger(fieldIndex++)});
4365 
4366             constituents.push_back(
4367                 cast(fieldId, fieldType, valueFieldTypeSpec, expectedFieldTypeSpec, nullptr));
4368         }
4369     }
4370     else
4371     {
4372         // Bool types in interface blocks are emulated with uint.  bool<->uint cast is done here.
4373         ASSERT(valueType.getBasicType() == EbtBool);
4374         ASSERT(valueTypeSpec.isOrHasBoolInInterfaceBlock ||
4375                expectedTypeSpec.isOrHasBoolInInterfaceBlock);
4376 
4377         TType emulatedValueType(valueType);
4378         emulatedValueType.setBasicType(EbtUInt);
4379         emulatedValueType.setPrecise(EbpLow);
4380 
4381         // If value is loaded as uint, it needs to change to bool.  If it's bool, it needs to change
4382         // to uint before storage.
4383         if (valueTypeSpec.isOrHasBoolInInterfaceBlock)
4384         {
4385             return castBasicType(value, emulatedValueType, valueType, resultTypeIdOut);
4386         }
4387         else
4388         {
4389             return castBasicType(value, valueType, emulatedValueType, resultTypeIdOut);
4390         }
4391     }
4392 
4393     // Construct the value with the expected type from its cast constituents.
4394     const spirv::IdRef expectedTypeId =
4395         mBuilder.getTypeDataOverrideTypeSpec(valueType, expectedTypeSpec).id;
4396     const spirv::IdRef expectedId = mBuilder.getNewId(mBuilder.getDecorations(valueType));
4397 
4398     spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), expectedTypeId,
4399                                    expectedId, constituents);
4400 
4401     if (resultTypeIdOut)
4402     {
4403         *resultTypeIdOut = expectedTypeId;
4404     }
4405 
4406     return expectedId;
4407 }
4408 
extendScalarParamsToVector(TIntermOperator * node,spirv::IdRef resultTypeId,spirv::IdRefList * parameters)4409 void OutputSPIRVTraverser::extendScalarParamsToVector(TIntermOperator *node,
4410                                                       spirv::IdRef resultTypeId,
4411                                                       spirv::IdRefList *parameters)
4412 {
4413     const TType &type = node->getType();
4414     if (type.isScalar())
4415     {
4416         // Nothing to do if the operation is applied to scalars.
4417         return;
4418     }
4419 
4420     const size_t childCount = node->getChildCount();
4421 
4422     for (size_t childIndex = 0; childIndex < childCount; ++childIndex)
4423     {
4424         const TType &childType = node->getChildNode(childIndex)->getAsTyped()->getType();
4425 
4426         // If the child is a scalar, replicate it to form a vector of the right size.
4427         if (childType.isScalar())
4428         {
4429             TType vectorType(type);
4430             if (vectorType.isMatrix())
4431             {
4432                 vectorType.toMatrixColumnType();
4433             }
4434             (*parameters)[childIndex] = createConstructorVectorFromScalar(
4435                 childType, vectorType, resultTypeId, {{(*parameters)[childIndex]}});
4436         }
4437     }
4438 }
4439 
reduceBoolVector(TOperator op,const spirv::IdRefList & valueIds,spirv::IdRef typeId,const SpirvDecorations & decorations)4440 spirv::IdRef OutputSPIRVTraverser::reduceBoolVector(TOperator op,
4441                                                     const spirv::IdRefList &valueIds,
4442                                                     spirv::IdRef typeId,
4443                                                     const SpirvDecorations &decorations)
4444 {
4445     if (valueIds.size() == 2)
4446     {
4447         // If two values are given, and/or them directly.
4448         WriteBinaryOp writeBinaryOp =
4449             op == EOpEqual ? spirv::WriteLogicalAnd : spirv::WriteLogicalOr;
4450         const spirv::IdRef result = mBuilder.getNewId(decorations);
4451 
4452         writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueIds[0],
4453                       valueIds[1]);
4454         return result;
4455     }
4456 
4457     WriteUnaryOp writeUnaryOp = op == EOpEqual ? spirv::WriteAll : spirv::WriteAny;
4458     spirv::IdRef valueId      = valueIds[0];
4459 
4460     if (valueIds.size() > 2)
4461     {
4462         // If multiple values are given, construct a bool vector out of them first.
4463         const spirv::IdRef bvecTypeId = mBuilder.getBasicTypeId(EbtBool, valueIds.size());
4464         valueId                       = {mBuilder.getNewId(decorations)};
4465 
4466         spirv::WriteCompositeConstruct(mBuilder.getSpirvCurrentFunctionBlock(), bvecTypeId, valueId,
4467                                        valueIds);
4468     }
4469 
4470     const spirv::IdRef result = mBuilder.getNewId(decorations);
4471     writeUnaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, valueId);
4472 
4473     return result;
4474 }
4475 
createCompareImpl(TOperator op,const TType & operandType,spirv::IdRef resultTypeId,spirv::IdRef leftId,spirv::IdRef rightId,const SpirvDecorations & operandDecorations,const SpirvDecorations & resultDecorations,spirv::LiteralIntegerList * currentAccessChain,spirv::IdRefList * intermediateResultsOut)4476 void OutputSPIRVTraverser::createCompareImpl(TOperator op,
4477                                              const TType &operandType,
4478                                              spirv::IdRef resultTypeId,
4479                                              spirv::IdRef leftId,
4480                                              spirv::IdRef rightId,
4481                                              const SpirvDecorations &operandDecorations,
4482                                              const SpirvDecorations &resultDecorations,
4483                                              spirv::LiteralIntegerList *currentAccessChain,
4484                                              spirv::IdRefList *intermediateResultsOut)
4485 {
4486     const TBasicType basicType = operandType.getBasicType();
4487     const bool isFloat         = basicType == EbtFloat;
4488     const bool isBool          = basicType == EbtBool;
4489 
4490     WriteBinaryOp writeBinaryOp = nullptr;
4491 
4492     // For arrays, compare them element by element.
4493     if (operandType.isArray())
4494     {
4495         TType elementType(operandType);
4496         elementType.toArrayElementType();
4497 
4498         currentAccessChain->emplace_back();
4499         for (unsigned int elementIndex = 0; elementIndex < operandType.getOutermostArraySize();
4500              ++elementIndex)
4501         {
4502             // Select the current element.
4503             currentAccessChain->back() = spirv::LiteralInteger(elementIndex);
4504 
4505             // Compare and accumulate the results.
4506             createCompareImpl(op, elementType, resultTypeId, leftId, rightId, operandDecorations,
4507                               resultDecorations, currentAccessChain, intermediateResultsOut);
4508         }
4509         currentAccessChain->pop_back();
4510 
4511         return;
4512     }
4513 
4514     // For structs, compare them field by field.
4515     if (operandType.getStruct() != nullptr)
4516     {
4517         uint32_t fieldIndex = 0;
4518 
4519         currentAccessChain->emplace_back();
4520         for (const TField *field : operandType.getStruct()->fields())
4521         {
4522             // Select the current field.
4523             currentAccessChain->back() = spirv::LiteralInteger(fieldIndex++);
4524 
4525             // Compare and accumulate the results.
4526             createCompareImpl(op, *field->type(), resultTypeId, leftId, rightId, operandDecorations,
4527                               resultDecorations, currentAccessChain, intermediateResultsOut);
4528         }
4529         currentAccessChain->pop_back();
4530 
4531         return;
4532     }
4533 
4534     // For matrices, compare them column by column.
4535     if (operandType.isMatrix())
4536     {
4537         TType columnType(operandType);
4538         columnType.toMatrixColumnType();
4539 
4540         currentAccessChain->emplace_back();
4541         for (uint8_t columnIndex = 0; columnIndex < operandType.getCols(); ++columnIndex)
4542         {
4543             // Select the current column.
4544             currentAccessChain->back() = spirv::LiteralInteger(columnIndex);
4545 
4546             // Compare and accumulate the results.
4547             createCompareImpl(op, columnType, resultTypeId, leftId, rightId, operandDecorations,
4548                               resultDecorations, currentAccessChain, intermediateResultsOut);
4549         }
4550         currentAccessChain->pop_back();
4551 
4552         return;
4553     }
4554 
4555     // For scalars and vectors generate a single instruction for comparison.
4556     if (op == EOpEqual)
4557     {
4558         if (isFloat)
4559             writeBinaryOp = spirv::WriteFOrdEqual;
4560         else if (isBool)
4561             writeBinaryOp = spirv::WriteLogicalEqual;
4562         else
4563             writeBinaryOp = spirv::WriteIEqual;
4564     }
4565     else
4566     {
4567         ASSERT(op == EOpNotEqual);
4568 
4569         if (isFloat)
4570             writeBinaryOp = spirv::WriteFUnordNotEqual;
4571         else if (isBool)
4572             writeBinaryOp = spirv::WriteLogicalNotEqual;
4573         else
4574             writeBinaryOp = spirv::WriteINotEqual;
4575     }
4576 
4577     // Extract the scalar and vector from composite types, if any.
4578     spirv::IdRef leftComponentId  = leftId;
4579     spirv::IdRef rightComponentId = rightId;
4580     if (!currentAccessChain->empty())
4581     {
4582         leftComponentId  = mBuilder.getNewId(operandDecorations);
4583         rightComponentId = mBuilder.getNewId(operandDecorations);
4584 
4585         const spirv::IdRef componentTypeId =
4586             mBuilder.getBasicTypeId(operandType.getBasicType(), operandType.getNominalSize());
4587 
4588         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4589                                      leftComponentId, leftId, *currentAccessChain);
4590         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), componentTypeId,
4591                                      rightComponentId, rightId, *currentAccessChain);
4592     }
4593 
4594     const bool reduceResult     = !operandType.isScalar();
4595     spirv::IdRef result         = mBuilder.getNewId({});
4596     spirv::IdRef opResultTypeId = resultTypeId;
4597     if (reduceResult)
4598     {
4599         opResultTypeId = mBuilder.getBasicTypeId(EbtBool, operandType.getNominalSize());
4600     }
4601 
4602     // Write the comparison operation itself.
4603     writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), opResultTypeId, result, leftComponentId,
4604                   rightComponentId);
4605 
4606     // If it's a vector, reduce the result.
4607     if (reduceResult)
4608     {
4609         result = reduceBoolVector(op, {result}, resultTypeId, resultDecorations);
4610     }
4611 
4612     intermediateResultsOut->push_back(result);
4613 }
4614 
makeBuiltInOutputStructType(TIntermOperator * node,size_t lvalueCount)4615 spirv::IdRef OutputSPIRVTraverser::makeBuiltInOutputStructType(TIntermOperator *node,
4616                                                                size_t lvalueCount)
4617 {
4618     // The built-ins with lvalues are in one of the following forms:
4619     //
4620     // - lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4621     // - builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4622     //
4623     // In SPIR-V, the result of all these instructions is a struct { lsb; msb; }.
4624 
4625     const size_t childCount = node->getChildCount();
4626     ASSERT(childCount >= 2);
4627 
4628     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4629     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4630 
4631     const TType &lsbType = lvalueCount == 1 ? node->getType() : lastChild->getType();
4632     const TType &msbType = lvalueCount == 1 ? lastChild->getType() : beforeLastChild->getType();
4633 
4634     ASSERT(lsbType.isScalar() || lsbType.isVector());
4635     ASSERT(msbType.isScalar() || msbType.isVector());
4636 
4637     const BuiltInResultStruct key = {
4638         lsbType.getBasicType(),
4639         msbType.getBasicType(),
4640         static_cast<uint32_t>(lsbType.getNominalSize()),
4641         static_cast<uint32_t>(msbType.getNominalSize()),
4642     };
4643 
4644     auto iter = mBuiltInResultStructMap.find(key);
4645     if (iter == mBuiltInResultStructMap.end())
4646     {
4647         // Create a TStructure and TType for the required structure.
4648         TType *lsbTypeCopy = new TType(lsbType.getBasicType(), lsbType.getNominalSize(), 1);
4649         TType *msbTypeCopy = new TType(msbType.getBasicType(), msbType.getNominalSize(), 1);
4650 
4651         TFieldList *fields = new TFieldList;
4652         fields->push_back(
4653             new TField(lsbTypeCopy, ImmutableString("lsb"), {}, SymbolType::AngleInternal));
4654         fields->push_back(
4655             new TField(msbTypeCopy, ImmutableString("msb"), {}, SymbolType::AngleInternal));
4656 
4657         TStructure *structure =
4658             new TStructure(&mCompiler->getSymbolTable(), ImmutableString("BuiltInResultType"),
4659                            fields, SymbolType::AngleInternal);
4660 
4661         TType structType(structure, true);
4662 
4663         // Get an id for the type and store in the hash map.
4664         const spirv::IdRef structTypeId = mBuilder.getTypeData(structType, {}).id;
4665         iter                            = mBuiltInResultStructMap.insert({key, structTypeId}).first;
4666     }
4667 
4668     return iter->second;
4669 }
4670 
4671 // Once the builtin instruction is generated, the two return values are extracted from the
4672 // struct.  These are written to the return value (if any) and the out parameters.
storeBuiltInStructOutputInParamsAndReturnValue(TIntermOperator * node,size_t lvalueCount,spirv::IdRef structValue,spirv::IdRef returnValue,spirv::IdRef returnValueType)4673 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamsAndReturnValue(
4674     TIntermOperator *node,
4675     size_t lvalueCount,
4676     spirv::IdRef structValue,
4677     spirv::IdRef returnValue,
4678     spirv::IdRef returnValueType)
4679 {
4680     const size_t childCount = node->getChildCount();
4681     ASSERT(childCount >= 2);
4682 
4683     TIntermTyped *lastChild       = node->getChildNode(childCount - 1)->getAsTyped();
4684     TIntermTyped *beforeLastChild = node->getChildNode(childCount - 2)->getAsTyped();
4685 
4686     if (lvalueCount == 1)
4687     {
4688         // The built-in is the form:
4689         //
4690         //     lsb = builtin(..., out msb): These are identified by lvalueCount == 1
4691 
4692         // Field 0 is lsb, which is extracted as the builtin's return value.
4693         spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), returnValueType,
4694                                      returnValue, structValue, {spirv::LiteralInteger(0)});
4695 
4696         // Field 1 is msb, which is extracted and stored through the out parameter.
4697         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4698                                               structValue, 1);
4699     }
4700     else
4701     {
4702         // The built-in is the form:
4703         //
4704         //     builtin(..., out msb, out lsb): These are identified by lvalueCount == 2
4705         ASSERT(lvalueCount == 2);
4706 
4707         // Field 0 is lsb, which is extracted and stored through the second out parameter.
4708         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 1], lastChild,
4709                                               structValue, 0);
4710 
4711         // Field 1 is msb, which is extracted and stored through the first out parameter.
4712         storeBuiltInStructOutputInParamHelper(&mNodeData[mNodeData.size() - 2], beforeLastChild,
4713                                               structValue, 1);
4714     }
4715 }
4716 
storeBuiltInStructOutputInParamHelper(NodeData * data,TIntermTyped * param,spirv::IdRef structValue,uint32_t fieldIndex)4717 void OutputSPIRVTraverser::storeBuiltInStructOutputInParamHelper(NodeData *data,
4718                                                                  TIntermTyped *param,
4719                                                                  spirv::IdRef structValue,
4720                                                                  uint32_t fieldIndex)
4721 {
4722     spirv::IdRef fieldTypeId  = mBuilder.getTypeData(param->getType(), {}).id;
4723     spirv::IdRef fieldValueId = mBuilder.getNewId(mBuilder.getDecorations(param->getType()));
4724 
4725     spirv::WriteCompositeExtract(mBuilder.getSpirvCurrentFunctionBlock(), fieldTypeId, fieldValueId,
4726                                  structValue, {spirv::LiteralInteger(fieldIndex)});
4727 
4728     accessChainStore(data, fieldValueId, param->getType());
4729 }
4730 
visitSymbol(TIntermSymbol * node)4731 void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
4732 {
4733     // No-op visits to symbols that are being declared.  They are handled in visitDeclaration.
4734     if (mIsSymbolBeingDeclared)
4735     {
4736         // Make sure this does not affect other symbols, for example in the initializer expression.
4737         mIsSymbolBeingDeclared = false;
4738         return;
4739     }
4740 
4741     mNodeData.emplace_back();
4742 
4743     // The symbol is either:
4744     //
4745     // - A specialization constant
4746     // - A variable (local, varying etc)
4747     // - An interface block
4748     // - A field of an unnamed interface block
4749     //
4750     // Specialization constants in SPIR-V are treated largely like constants, in which case make
4751     // this behave like visitConstantUnion().
4752 
4753     const TType &type                     = node->getType();
4754     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
4755     const TSymbol *symbol                 = interfaceBlock;
4756     if (interfaceBlock == nullptr)
4757     {
4758         symbol = &node->variable();
4759     }
4760 
4761     // Track the properties that lead to the symbol's specific SPIR-V type based on the GLSL type.
4762     // They are needed to determine the derived type in an access chain, but are not promoted in
4763     // intermediate nodes' TTypes.
4764     SpirvTypeSpec typeSpec;
4765     typeSpec.inferDefaults(type, mCompiler);
4766 
4767     const spirv::IdRef typeId = mBuilder.getTypeData(type, typeSpec).id;
4768 
4769     // If the symbol is a const variable, a const function parameter or specialization constant,
4770     // create an rvalue.
4771     if (type.getQualifier() == EvqConst || type.getQualifier() == EvqParamConst ||
4772         type.getQualifier() == EvqSpecConst)
4773     {
4774         ASSERT(interfaceBlock == nullptr);
4775         ASSERT(mSymbolIdMap.count(symbol) > 0);
4776         nodeDataInitRValue(&mNodeData.back(), mSymbolIdMap[symbol], typeId);
4777         return;
4778     }
4779 
4780     // Otherwise create an lvalue.
4781     spv::StorageClass storageClass;
4782     const spirv::IdRef symbolId = getSymbolIdAndStorageClass(symbol, type, &storageClass);
4783 
4784     nodeDataInitLValue(&mNodeData.back(), symbolId, typeId, storageClass, typeSpec);
4785 
4786     // If a field of a nameless interface block, create an access chain.
4787     if (type.getInterfaceBlock() && !type.isInterfaceBlock())
4788     {
4789         uint32_t fieldIndex = static_cast<uint32_t>(type.getInterfaceBlockFieldIndex());
4790         accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(fieldIndex), typeId);
4791     }
4792 
4793     // Add gl_PerVertex capabilities only if the field is actually used.
4794     switch (type.getQualifier())
4795     {
4796         case EvqClipDistance:
4797             mBuilder.addCapability(spv::CapabilityClipDistance);
4798             break;
4799         case EvqCullDistance:
4800             mBuilder.addCapability(spv::CapabilityCullDistance);
4801             break;
4802         default:
4803             break;
4804     }
4805 }
4806 
visitConstantUnion(TIntermConstantUnion * node)4807 void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
4808 {
4809     mNodeData.emplace_back();
4810 
4811     const TType &type = node->getType();
4812 
4813     // Find out the expected type for this constant, so it can be cast right away and not need an
4814     // instruction to do that.
4815     TIntermNode *parent     = getParentNode();
4816     const size_t childIndex = getParentChildIndex(PreVisit);
4817 
4818     TBasicType expectedBasicType = type.getBasicType();
4819     if (parent->getAsAggregate())
4820     {
4821         TIntermAggregate *parentAggregate = parent->getAsAggregate();
4822 
4823         // Note that only constructors can cast a type.  There are two possibilities:
4824         //
4825         // - It's a struct constructor: The basic type must match that of the corresponding field of
4826         //   the struct.
4827         // - It's a non struct constructor: The basic type must match that of the type being
4828         //   constructed.
4829         if (parentAggregate->isConstructor())
4830         {
4831             const TType &parentType     = parentAggregate->getType();
4832             const TStructure *structure = parentType.getStruct();
4833 
4834             if (structure != nullptr && !parentType.isArray())
4835             {
4836                 expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
4837             }
4838             else
4839             {
4840                 expectedBasicType = parentAggregate->getType().getBasicType();
4841             }
4842         }
4843     }
4844 
4845     const spirv::IdRef typeId  = mBuilder.getTypeData(type, {}).id;
4846     const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue(),
4847                                                 node->isConstantNullValue());
4848 
4849     nodeDataInitRValue(&mNodeData.back(), constId, typeId);
4850 }
4851 
visitSwizzle(Visit visit,TIntermSwizzle * node)4852 bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
4853 {
4854     // Constants are expected to be folded.
4855     ASSERT(!node->hasConstantValue());
4856 
4857     if (visit == PreVisit)
4858     {
4859         // Don't add an entry to the stack.  The child will create one, which we won't pop.
4860         return true;
4861     }
4862 
4863     ASSERT(visit == PostVisit);
4864     ASSERT(mNodeData.size() >= 1);
4865 
4866     const TType &vectorType            = node->getOperand()->getType();
4867     const uint8_t vectorComponentCount = static_cast<uint8_t>(vectorType.getNominalSize());
4868     const TVector<int> &swizzle        = node->getSwizzleOffsets();
4869 
4870     // As an optimization, do nothing if the swizzle is selecting all the components of the vector
4871     // in order.
4872     bool isIdentity = swizzle.size() == vectorComponentCount;
4873     for (size_t index = 0; index < swizzle.size(); ++index)
4874     {
4875         isIdentity = isIdentity && static_cast<size_t>(swizzle[index]) == index;
4876     }
4877 
4878     if (isIdentity)
4879     {
4880         return true;
4881     }
4882 
4883     accessChainOnPush(&mNodeData.back(), vectorType, 0);
4884 
4885     const spirv::IdRef typeId =
4886         mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4887 
4888     accessChainPushSwizzle(&mNodeData.back(), swizzle, typeId, vectorComponentCount);
4889 
4890     return true;
4891 }
4892 
visitBinary(Visit visit,TIntermBinary * node)4893 bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
4894 {
4895     // Constants are expected to be folded.
4896     ASSERT(!node->hasConstantValue());
4897 
4898     if (visit == PreVisit)
4899     {
4900         // Don't add an entry to the stack.  The left child will create one, which we won't pop.
4901         return true;
4902     }
4903 
4904     // If this is a variable initialization node, defer any code generation to visitDeclaration.
4905     if (node->getOp() == EOpInitialize)
4906     {
4907         ASSERT(getParentNode()->getAsDeclarationNode() != nullptr);
4908         return true;
4909     }
4910 
4911     if (IsShortCircuitNeeded(node))
4912     {
4913         // For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
4914         // |if| construct.  At this point, the left-hand side is already evaluated, so we need to
4915         // create an appropriate conditional on in-visit and visit the right-hand-side inside the
4916         // conditional block.  On post-visit, OpPhi is used to calculate the result.
4917         if (visit == InVisit)
4918         {
4919             startShortCircuit(node);
4920             return true;
4921         }
4922 
4923         spirv::IdRef typeId;
4924         const spirv::IdRef result = endShortCircuit(node, &typeId);
4925 
4926         // Replace the access chain with an rvalue that's the result.
4927         nodeDataInitRValue(&mNodeData.back(), result, typeId);
4928 
4929         return true;
4930     }
4931 
4932     if (visit == InVisit)
4933     {
4934         // Left child visited.  Take the entry it created as the current node's.
4935         ASSERT(mNodeData.size() >= 1);
4936 
4937         // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
4938         // add it to the access chain as literal.
4939         switch (node->getOp())
4940         {
4941             default:
4942                 break;
4943 
4944             case EOpIndexDirect:
4945             case EOpIndexDirectStruct:
4946             case EOpIndexDirectInterfaceBlock:
4947                 const uint32_t index = node->getRight()->getAsConstantUnion()->getIConst(0);
4948                 accessChainOnPush(&mNodeData.back(), node->getLeft()->getType(), index);
4949 
4950                 const spirv::IdRef typeId =
4951                     mBuilder.getTypeData(node->getType(), mNodeData.back().accessChain.typeSpec).id;
4952                 accessChainPushLiteral(&mNodeData.back(), spirv::LiteralInteger(index), typeId);
4953 
4954                 // Don't visit the right child, it's already processed.
4955                 return false;
4956         }
4957 
4958         return true;
4959     }
4960 
4961     // There are at least two entries, one for the left node and one for the right one.
4962     ASSERT(mNodeData.size() >= 2);
4963 
4964     SpirvTypeSpec resultTypeSpec;
4965     if (node->getOp() == EOpIndexIndirect || node->getOp() == EOpAssign)
4966     {
4967         if (node->getOp() == EOpIndexIndirect)
4968         {
4969             accessChainOnPush(&mNodeData[mNodeData.size() - 2], node->getLeft()->getType(), 0);
4970         }
4971         resultTypeSpec = mNodeData[mNodeData.size() - 2].accessChain.typeSpec;
4972     }
4973     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), resultTypeSpec).id;
4974 
4975     // For EOpIndex* operations, push the right value as an index to the left value's access chain.
4976     // For the other operations, evaluate the expression.
4977     switch (node->getOp())
4978     {
4979         case EOpIndexDirect:
4980         case EOpIndexDirectStruct:
4981         case EOpIndexDirectInterfaceBlock:
4982             UNREACHABLE();
4983             break;
4984         case EOpIndexIndirect:
4985         {
4986             // Load the index.
4987             const spirv::IdRef rightValue =
4988                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
4989             mNodeData.pop_back();
4990 
4991             if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
4992             {
4993                 accessChainPushDynamicComponent(&mNodeData.back(), rightValue, resultTypeId);
4994             }
4995             else
4996             {
4997                 accessChainPush(&mNodeData.back(), rightValue, resultTypeId);
4998             }
4999             break;
5000         }
5001 
5002         case EOpAssign:
5003         {
5004             // Load the right hand side of assignment.
5005             const spirv::IdRef rightValue =
5006                 accessChainLoad(&mNodeData.back(), node->getRight()->getType(), nullptr);
5007             mNodeData.pop_back();
5008 
5009             // Store into the access chain.  Since the result of the (a = b) expression is b, change
5010             // the access chain to an unindexed rvalue which is |rightValue|.
5011             accessChainStore(&mNodeData.back(), rightValue, node->getLeft()->getType());
5012             nodeDataInitRValue(&mNodeData.back(), rightValue, resultTypeId);
5013             break;
5014         }
5015 
5016         case EOpComma:
5017             // When the expression a,b is visited, all side effects of a and b are already
5018             // processed.  What's left is to to replace the expression with the result of b.  This
5019             // is simply done by dropping the left node and placing the right node as the result.
5020             mNodeData.erase(mNodeData.begin() + mNodeData.size() - 2);
5021             break;
5022 
5023         default:
5024             const spirv::IdRef result = visitOperator(node, resultTypeId);
5025             mNodeData.pop_back();
5026             nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5027             break;
5028     }
5029 
5030     return true;
5031 }
5032 
visitUnary(Visit visit,TIntermUnary * node)5033 bool OutputSPIRVTraverser::visitUnary(Visit visit, TIntermUnary *node)
5034 {
5035     // Constants are expected to be folded.
5036     ASSERT(!node->hasConstantValue());
5037 
5038     // Special case EOpArrayLength.
5039     if (node->getOp() == EOpArrayLength)
5040     {
5041         visitArrayLength(node);
5042 
5043         // Children already visited.
5044         return false;
5045     }
5046 
5047     if (visit == PreVisit)
5048     {
5049         // Don't add an entry to the stack.  The child will create one, which we won't pop.
5050         return true;
5051     }
5052 
5053     // It's a unary operation, so there can't be an InVisit.
5054     ASSERT(visit != InVisit);
5055 
5056     // There is at least on entry for the child.
5057     ASSERT(mNodeData.size() >= 1);
5058 
5059     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5060     const spirv::IdRef result       = visitOperator(node, resultTypeId);
5061 
5062     // Keep the result as rvalue.
5063     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5064 
5065     return true;
5066 }
5067 
visitTernary(Visit visit,TIntermTernary * node)5068 bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
5069 {
5070     if (visit == PreVisit)
5071     {
5072         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5073         return true;
5074     }
5075 
5076     size_t lastChildIndex = getLastTraversedChildIndex(visit);
5077 
5078     // If the condition was just visited, evaluate it and decide if OpSelect could be used or an
5079     // if-else must be emitted.  OpSelect is only used if neither side has a side effect.  SPIR-V
5080     // prior to 1.4 requires the type to be either scalar or vector.
5081     const TType &type   = node->getType();
5082     bool canUseOpSelect = (type.isScalar() || type.isVector() || mCompileOptions.emitSPIRV14) &&
5083                           !node->getTrueExpression()->hasSideEffects() &&
5084                           !node->getFalseExpression()->hasSideEffects();
5085 
5086     // Don't use OpSelect on buggy drivers.  Technically this is only needed if the two sides don't
5087     // have matching use of RelaxedPrecision, but not worth being precise about it.
5088     if (mCompileOptions.avoidOpSelectWithMismatchingRelaxedPrecision)
5089     {
5090         canUseOpSelect = false;
5091     }
5092 
5093     if (lastChildIndex == 0)
5094     {
5095         const TType &conditionType = node->getCondition()->getType();
5096 
5097         spirv::IdRef typeId;
5098         spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back(), conditionType, &typeId);
5099 
5100         // If OpSelect can be used, keep the condition for later usage.
5101         if (canUseOpSelect)
5102         {
5103             // SPIR-V prior to 1.4 requires that the condition value have as many components as the
5104             // result.  So when selecting between vectors, we must replicate the condition scalar.
5105             if (!mCompileOptions.emitSPIRV14 && type.isVector())
5106             {
5107                 const TType &boolVectorType =
5108                     *StaticType::GetForVec<EbtBool, EbpUndefined>(EvqGlobal, type.getNominalSize());
5109                 typeId =
5110                     mBuilder.getBasicTypeId(conditionType.getBasicType(), type.getNominalSize());
5111                 conditionValue = createConstructorVectorFromScalar(conditionType, boolVectorType,
5112                                                                    typeId, {{conditionValue}});
5113             }
5114             nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5115             return true;
5116         }
5117 
5118         // Otherwise generate an if-else construct.
5119 
5120         // Three blocks necessary; the true, false and merge.
5121         mBuilder.startConditional(3, false, false);
5122 
5123         // Generate the branch instructions.
5124         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5125 
5126         const spirv::IdRef trueBlockId  = conditional->blockIds[0];
5127         const spirv::IdRef falseBlockId = conditional->blockIds[1];
5128         const spirv::IdRef mergeBlockId = conditional->blockIds.back();
5129 
5130         mBuilder.writeBranchConditional(conditionValue, trueBlockId, falseBlockId, mergeBlockId);
5131         nodeDataInitRValue(&mNodeData.back(), conditionValue, typeId);
5132         return true;
5133     }
5134 
5135     // Load the result of the true or false part, and keep it for the end.  It's either used in
5136     // OpSelect or OpPhi.
5137     spirv::IdRef typeId;
5138     const spirv::IdRef value = accessChainLoad(&mNodeData.back(), type, &typeId);
5139     mNodeData.pop_back();
5140     mNodeData.back().idList.push_back(value);
5141 
5142     // Additionally store the id of block that has produced the result.
5143     mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
5144 
5145     if (!canUseOpSelect)
5146     {
5147         // Move on to the next block.
5148         mBuilder.writeBranchConditionalBlockEnd();
5149     }
5150 
5151     // When done, generate either OpSelect or OpPhi.
5152     if (visit == PostVisit)
5153     {
5154         const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
5155 
5156         ASSERT(mNodeData.back().idList.size() == 4);
5157         const spirv::IdRef trueValue    = mNodeData.back().idList[0].id;
5158         const spirv::IdRef trueBlockId  = mNodeData.back().idList[1].id;
5159         const spirv::IdRef falseValue   = mNodeData.back().idList[2].id;
5160         const spirv::IdRef falseBlockId = mNodeData.back().idList[3].id;
5161 
5162         if (canUseOpSelect)
5163         {
5164             const spirv::IdRef conditionValue = mNodeData.back().baseId;
5165 
5166             spirv::WriteSelect(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5167                                conditionValue, trueValue, falseValue);
5168         }
5169         else
5170         {
5171             spirv::WritePhi(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result,
5172                             {spirv::PairIdRefIdRef{trueValue, trueBlockId},
5173                              spirv::PairIdRefIdRef{falseValue, falseBlockId}});
5174 
5175             mBuilder.endConditional();
5176         }
5177 
5178         // Replace the access chain with an rvalue that's the result.
5179         nodeDataInitRValue(&mNodeData.back(), result, typeId);
5180     }
5181 
5182     return true;
5183 }
5184 
visitIfElse(Visit visit,TIntermIfElse * node)5185 bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
5186 {
5187     // An if condition may or may not have an else block.  When both blocks are present, the
5188     // translation is as follows:
5189     //
5190     // if (cond) { trueBody } else { falseBody }
5191     //
5192     //               // pre-if block
5193     //       %cond = ...
5194     //               OpSelectionMerge %merge None
5195     //               OpBranchConditional %cond %true %false
5196     //
5197     //       %true = OpLabel
5198     //               trueBody
5199     //               OpBranch %merge
5200     //
5201     //      %false = OpLabel
5202     //               falseBody
5203     //               OpBranch %merge
5204     //
5205     //               // post-if block
5206     //       %merge = OpLabel
5207     //
5208     // If the else block is missing, OpBranchConditional will simply jump to %merge on the false
5209     // condition and the %false block is removed.  Due to the way ParseContext prunes compile-time
5210     // constant conditionals, the if block itself may also be missing, which is treated similarly.
5211 
5212     // It's simpler if this function performs the traversal.
5213     ASSERT(visit == PreVisit);
5214 
5215     // Visit the condition.
5216     node->getCondition()->traverse(this);
5217     const spirv::IdRef conditionValue =
5218         accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
5219 
5220     // If both true and false blocks are missing, there's nothing to do.
5221     if (node->getTrueBlock() == nullptr && node->getFalseBlock() == nullptr)
5222     {
5223         return false;
5224     }
5225 
5226     // Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
5227     // else block (if any), and one for the merge block.  getChildCount() works here as it
5228     // produces an identical count.
5229     mBuilder.startConditional(node->getChildCount(), false, false);
5230 
5231     // Generate the branch instructions.
5232     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5233 
5234     const spirv::IdRef mergeBlock = conditional->blockIds.back();
5235     spirv::IdRef trueBlock        = mergeBlock;
5236     spirv::IdRef falseBlock       = mergeBlock;
5237 
5238     size_t nextBlockIndex = 0;
5239     if (node->getTrueBlock())
5240     {
5241         trueBlock = conditional->blockIds[nextBlockIndex++];
5242     }
5243     if (node->getFalseBlock())
5244     {
5245         falseBlock = conditional->blockIds[nextBlockIndex++];
5246     }
5247 
5248     mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
5249 
5250     // Visit the true block, if any.
5251     if (node->getTrueBlock())
5252     {
5253         node->getTrueBlock()->traverse(this);
5254         mBuilder.writeBranchConditionalBlockEnd();
5255     }
5256 
5257     // Visit the false block, if any.
5258     if (node->getFalseBlock())
5259     {
5260         node->getFalseBlock()->traverse(this);
5261         mBuilder.writeBranchConditionalBlockEnd();
5262     }
5263 
5264     // Pop from the conditional stack when done.
5265     mBuilder.endConditional();
5266 
5267     // Don't traverse the children, that's done already.
5268     return false;
5269 }
5270 
visitSwitch(Visit visit,TIntermSwitch * node)5271 bool OutputSPIRVTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
5272 {
5273     // Take the following switch:
5274     //
5275     //     switch (c)
5276     //     {
5277     //     case A:
5278     //         ABlock;
5279     //         break;
5280     //     case B:
5281     //     default:
5282     //         BBlock;
5283     //         break;
5284     //     case C:
5285     //         CBlock;
5286     //         // fallthrough
5287     //     case D:
5288     //         DBlock;
5289     //     }
5290     //
5291     // In SPIR-V, this is implemented similarly to the following pseudo-code:
5292     //
5293     //     switch c:
5294     //         A       -> jump %A
5295     //         B       -> jump %B
5296     //         C       -> jump %C
5297     //         D       -> jump %D
5298     //         default -> jump %B
5299     //
5300     //     %A:
5301     //         ABlock
5302     //         jump %merge
5303     //
5304     //     %B:
5305     //         BBlock
5306     //         jump %merge
5307     //
5308     //     %C:
5309     //         CBlock
5310     //         jump %D
5311     //
5312     //     %D:
5313     //         DBlock
5314     //         jump %merge
5315     //
5316     // The OpSwitch instruction contains the jump labels for the default and other cases.  Each
5317     // block either terminates with a jump to the merge block or the next block as fallthrough.
5318     //
5319     //               // pre-switch block
5320     //               OpSelectionMerge %merge None
5321     //               OpSwitch %cond %C A %A B %B C %C D %D
5322     //
5323     //          %A = OpLabel
5324     //               ABlock
5325     //               OpBranch %merge
5326     //
5327     //          %B = OpLabel
5328     //               BBlock
5329     //               OpBranch %merge
5330     //
5331     //          %C = OpLabel
5332     //               CBlock
5333     //               OpBranch %D
5334     //
5335     //          %D = OpLabel
5336     //               DBlock
5337     //               OpBranch %merge
5338 
5339     if (visit == PreVisit)
5340     {
5341         // Artificially add `if (true)` around switches as a driver bug workaround
5342         if (mCompileOptions.wrapSwitchInIfTrue)
5343         {
5344             const spirv::IdRef conditionValue = mBuilder.getBoolConstant(true);
5345             mBuilder.startConditional(2, false, false);
5346             const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5347             const spirv::IdRef trueBlock        = conditional->blockIds[0];
5348             const spirv::IdRef mergeBlock       = conditional->blockIds[1];
5349             mBuilder.writeBranchConditional(conditionValue, trueBlock, mergeBlock, mergeBlock);
5350         }
5351 
5352         // Don't add an entry to the stack.  The condition will create one, which we won't pop.
5353         return true;
5354     }
5355 
5356     // If the condition was just visited, evaluate it and create the switch instruction.
5357     if (visit == InVisit)
5358     {
5359         ASSERT(getLastTraversedChildIndex(visit) == 0);
5360 
5361         const spirv::IdRef conditionValue =
5362             accessChainLoad(&mNodeData.back(), node->getInit()->getType(), nullptr);
5363 
5364         // First, need to find out how many blocks are there in the switch.
5365         const TIntermSequence &statements = *node->getStatementList()->getSequence();
5366         bool lastWasCase                  = true;
5367         size_t blockIndex                 = 0;
5368 
5369         size_t defaultBlockIndex = std::numeric_limits<size_t>::max();
5370         TVector<uint32_t> caseValues;
5371         TVector<size_t> caseBlockIndices;
5372 
5373         for (TIntermNode *statement : statements)
5374         {
5375             TIntermCase *caseLabel = statement->getAsCaseNode();
5376             const bool isCaseLabel = caseLabel != nullptr;
5377 
5378             if (isCaseLabel)
5379             {
5380                 // For every case label, remember its block index.  This is used later to generate
5381                 // the OpSwitch instruction.
5382                 if (caseLabel->hasCondition())
5383                 {
5384                     // All switch conditions are literals.
5385                     TIntermConstantUnion *condition =
5386                         caseLabel->getCondition()->getAsConstantUnion();
5387                     ASSERT(condition != nullptr);
5388 
5389                     TConstantUnion caseValue;
5390                     if (condition->getType().getBasicType() == EbtYuvCscStandardEXT)
5391                     {
5392                         caseValue.setUConst(
5393                             condition->getConstantValue()->getYuvCscStandardEXTConst());
5394                     }
5395                     else
5396                     {
5397                         bool valid = caseValue.cast(EbtUInt, *condition->getConstantValue());
5398                         ASSERT(valid);
5399                     }
5400 
5401                     caseValues.push_back(caseValue.getUConst());
5402                     caseBlockIndices.push_back(blockIndex);
5403                 }
5404                 else
5405                 {
5406                     // Remember the block index of the default case.
5407                     defaultBlockIndex = blockIndex;
5408                 }
5409                 lastWasCase = true;
5410             }
5411             else if (lastWasCase)
5412             {
5413                 // Every time a non-case node is visited and the previous statement was a case node,
5414                 // it's a new block.
5415                 ++blockIndex;
5416                 lastWasCase = false;
5417             }
5418         }
5419 
5420         // Block count is the number of blocks based on cases + 1 for the merge block.
5421         const size_t blockCount = blockIndex + 1;
5422         mBuilder.startConditional(blockCount, false, true);
5423 
5424         // Generate the switch instructions.
5425         const SpirvConditional *conditional = mBuilder.getCurrentConditional();
5426 
5427         // Generate the list of caseValue->blockIndex mapping used by the OpSwitch instruction.  If
5428         // the switch ends in a number of cases with no statements following them, they will
5429         // naturally jump to the merge block!
5430         spirv::PairLiteralIntegerIdRefList switchTargets;
5431 
5432         for (size_t caseIndex = 0; caseIndex < caseValues.size(); ++caseIndex)
5433         {
5434             uint32_t value        = caseValues[caseIndex];
5435             size_t caseBlockIndex = caseBlockIndices[caseIndex];
5436 
5437             switchTargets.push_back(
5438                 {spirv::LiteralInteger(value), conditional->blockIds[caseBlockIndex]});
5439         }
5440 
5441         const spirv::IdRef mergeBlock   = conditional->blockIds.back();
5442         const spirv::IdRef defaultBlock = defaultBlockIndex <= caseValues.size()
5443                                               ? conditional->blockIds[defaultBlockIndex]
5444                                               : mergeBlock;
5445 
5446         mBuilder.writeSwitch(conditionValue, defaultBlock, switchTargets, mergeBlock);
5447         return true;
5448     }
5449 
5450     // Terminate the last block if not already and end the conditional.
5451     mBuilder.writeSwitchCaseBlockEnd();
5452     mBuilder.endConditional();
5453 
5454     if (mCompileOptions.wrapSwitchInIfTrue)
5455     {
5456         mBuilder.writeBranchConditionalBlockEnd();
5457         mBuilder.endConditional();
5458     }
5459 
5460     return true;
5461 }
5462 
visitCase(Visit visit,TIntermCase * node)5463 bool OutputSPIRVTraverser::visitCase(Visit visit, TIntermCase *node)
5464 {
5465     ASSERT(visit == PreVisit);
5466 
5467     mNodeData.emplace_back();
5468 
5469     TIntermBlock *parent    = getParentNode()->getAsBlock();
5470     const size_t childIndex = getParentChildIndex(PreVisit);
5471 
5472     ASSERT(parent);
5473     const TIntermSequence &parentStatements = *parent->getSequence();
5474 
5475     // Check the previous statement.  If it was not a |case|, then a new block is being started so
5476     // handle fallthrough:
5477     //
5478     //     ...
5479     //        statement;
5480     //     case X:         <--- end the previous block here
5481     //     case Y:
5482     //
5483     //
5484     if (childIndex > 0 && parentStatements[childIndex - 1]->getAsCaseNode() == nullptr)
5485     {
5486         mBuilder.writeSwitchCaseBlockEnd();
5487     }
5488 
5489     // Don't traverse the condition, as it was processed in visitSwitch.
5490     return false;
5491 }
5492 
visitBlock(Visit visit,TIntermBlock * node)5493 bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
5494 {
5495     // If global block, nothing to do.
5496     if (getCurrentTraversalDepth() == 0)
5497     {
5498         return true;
5499     }
5500 
5501     // Any construct that needs code blocks must have already handled creating the necessary blocks
5502     // and setting the right one "current".  If there's a block opened in GLSL for scoping reasons,
5503     // it's ignored here as there are no scopes within a function in SPIR-V.
5504     if (visit == PreVisit)
5505     {
5506         return node->getChildCount() > 0;
5507     }
5508 
5509     // Any node that needed to generate code has already done so, just clean up its data.  If
5510     // the child node has no effect, it's automatically discarded (such as variable.field[n].x,
5511     // side effects of n already having generated code).
5512     //
5513     // Blocks inside blocks like:
5514     //
5515     //     {
5516     //         statement;
5517     //         {
5518     //             statement2;
5519     //         }
5520     //     }
5521     //
5522     // don't generate nodes.
5523     const size_t childIndex           = getLastTraversedChildIndex(visit);
5524     const TIntermSequence &statements = *node->getSequence();
5525 
5526     if (statements[childIndex]->getAsBlock() == nullptr)
5527     {
5528         mNodeData.pop_back();
5529     }
5530 
5531     return true;
5532 }
5533 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)5534 bool OutputSPIRVTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
5535 {
5536     if (visit == PreVisit)
5537     {
5538         return true;
5539     }
5540 
5541     const TFunction *function = node->getFunction();
5542 
5543     ASSERT(mFunctionIdMap.count(function) > 0);
5544     const FunctionIds &ids = mFunctionIdMap[function];
5545 
5546     // After the prototype is visited, generate the initial code for the function.
5547     if (visit == InVisit)
5548     {
5549         // Declare the function.
5550         spirv::WriteFunction(mBuilder.getSpirvFunctions(), ids.returnTypeId, ids.functionId,
5551                              spv::FunctionControlMaskNone, ids.functionTypeId);
5552 
5553         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5554         {
5555             const TVariable *paramVariable = function->getParam(paramIndex);
5556 
5557             const spirv::IdRef paramId =
5558                 mBuilder.getNewId(mBuilder.getDecorations(paramVariable->getType()));
5559             spirv::WriteFunctionParameter(mBuilder.getSpirvFunctions(),
5560                                           ids.parameterTypeIds[paramIndex], paramId);
5561 
5562             // Remember the id of the variable for future look up.
5563             ASSERT(mSymbolIdMap.count(paramVariable) == 0);
5564             mSymbolIdMap[paramVariable] = paramId;
5565 
5566             mBuilder.writeDebugName(paramId, mBuilder.getName(paramVariable).data());
5567         }
5568 
5569         mBuilder.startNewFunction(ids.functionId, function);
5570 
5571         // For main(), add a non-semantic instruction at the beginning for any initialization code
5572         // the transformer may want to add.
5573         if (ids.functionId == vk::spirv::kIdEntryPoint &&
5574             mCompiler->getShaderType() != GL_COMPUTE_SHADER)
5575         {
5576             ASSERT(function->isMain());
5577             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticEnter);
5578         }
5579 
5580         mCurrentFunctionId = ids.functionId;
5581 
5582         return true;
5583     }
5584 
5585     // If no explicit return was specified, add one automatically here.
5586     if (!mBuilder.isCurrentFunctionBlockTerminated())
5587     {
5588         if (function->getReturnType().getBasicType() == EbtVoid)
5589         {
5590             switch (ids.functionId)
5591             {
5592                 case vk::spirv::kIdEntryPoint:
5593                     // For main(), add a non-semantic instruction at the end of the shader.
5594                     markVertexOutputOnShaderEnd();
5595                     break;
5596                 case vk::spirv::kIdXfbEmulationCaptureFunction:
5597                     // For the transform feedback emulation capture function, add a non-semantic
5598                     // instruction before return for the transformer to fill in as necessary.
5599                     mBuilder.writeNonSemanticInstruction(
5600                         vk::spirv::kNonSemanticTransformFeedbackEmulation);
5601                     break;
5602             }
5603 
5604             spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
5605         }
5606         else
5607         {
5608             // GLSL allows functions expecting a return value to miss a return.  In that case,
5609             // return a null constant.
5610             const TType &returnType = function->getReturnType();
5611             spirv::IdRef nullConstant;
5612             if (returnType.isScalar() && !returnType.isArray())
5613             {
5614                 switch (function->getReturnType().getBasicType())
5615                 {
5616                     case EbtFloat:
5617                         nullConstant = mBuilder.getFloatConstant(0);
5618                         break;
5619                     case EbtUInt:
5620                         nullConstant = mBuilder.getUintConstant(0);
5621                         break;
5622                     case EbtInt:
5623                         nullConstant = mBuilder.getIntConstant(0);
5624                         break;
5625                     default:
5626                         break;
5627                 }
5628             }
5629             if (!nullConstant.valid())
5630             {
5631                 nullConstant = mBuilder.getNullConstant(ids.returnTypeId);
5632             }
5633             spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), nullConstant);
5634         }
5635         mBuilder.terminateCurrentFunctionBlock();
5636     }
5637 
5638     mBuilder.assembleSpirvFunctionBlocks();
5639 
5640     // End the function
5641     spirv::WriteFunctionEnd(mBuilder.getSpirvFunctions());
5642 
5643     mCurrentFunctionId = {};
5644 
5645     return true;
5646 }
5647 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)5648 bool OutputSPIRVTraverser::visitGlobalQualifierDeclaration(Visit visit,
5649                                                            TIntermGlobalQualifierDeclaration *node)
5650 {
5651     if (node->isPrecise())
5652     {
5653         // Nothing to do for |precise|.
5654         return false;
5655     }
5656 
5657     // Global qualifier declarations apply to variables that are already declared.  Invariant simply
5658     // adds a decoration to the variable declaration, which can be done right away.  Note that
5659     // invariant cannot be applied to block members like this, except for gl_PerVertex built-ins,
5660     // which are applied to the members directly by DeclarePerVertexBlocks.
5661     ASSERT(node->isInvariant());
5662 
5663     const TVariable *variable = &node->getSymbol()->variable();
5664     ASSERT(mSymbolIdMap.count(variable) > 0);
5665 
5666     const spirv::IdRef variableId = mSymbolIdMap[variable];
5667 
5668     spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationInvariant, {});
5669 
5670     return false;
5671 }
5672 
visitFunctionPrototype(TIntermFunctionPrototype * node)5673 void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
5674 {
5675     const TFunction *function = node->getFunction();
5676 
5677     // If the function was previously forward declared, skip this.
5678     if (mFunctionIdMap.count(function) > 0)
5679     {
5680         return;
5681     }
5682 
5683     FunctionIds ids;
5684 
5685     // Declare the function type
5686     ids.returnTypeId = mBuilder.getTypeData(function->getReturnType(), {}).id;
5687 
5688     spirv::IdRefList paramTypeIds;
5689     for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
5690     {
5691         const TType &paramType = function->getParam(paramIndex)->getType();
5692 
5693         spirv::IdRef paramId = mBuilder.getTypeData(paramType, {}).id;
5694 
5695         // const function parameters are intermediate values, while the rest are "variables"
5696         // with the Function storage class.
5697         if (paramType.getQualifier() != EvqParamConst)
5698         {
5699             const spv::StorageClass storageClass = IsOpaqueType(paramType.getBasicType())
5700                                                        ? spv::StorageClassUniformConstant
5701                                                        : spv::StorageClassFunction;
5702             paramId                              = mBuilder.getTypePointerId(paramId, storageClass);
5703         }
5704 
5705         ids.parameterTypeIds.push_back(paramId);
5706     }
5707 
5708     ids.functionTypeId = mBuilder.getFunctionTypeId(ids.returnTypeId, ids.parameterTypeIds);
5709 
5710     // Allocate an id for the function up-front.
5711     //
5712     // Apply decorations to the return value of the function by applying them to the OpFunction
5713     // instruction.
5714     //
5715     // Note that some functions have predefined ids.
5716     if (function->isMain())
5717     {
5718         ids.functionId = spirv::IdRef(vk::spirv::kIdEntryPoint);
5719     }
5720     else
5721     {
5722         ids.functionId = mBuilder.getReservedOrNewId(
5723             function->uniqueId(), mBuilder.getDecorations(function->getReturnType()));
5724     }
5725 
5726     // Remember the id of the function for future look up.
5727     mFunctionIdMap[function] = ids;
5728 }
5729 
visitAggregate(Visit visit,TIntermAggregate * node)5730 bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
5731 {
5732     // Constants are expected to be folded.  However, large constructors (such as arrays) are not
5733     // folded and are handled here.
5734     ASSERT(node->getOp() == EOpConstruct || !node->hasConstantValue());
5735 
5736     if (visit == PreVisit)
5737     {
5738         mNodeData.emplace_back();
5739         return true;
5740     }
5741 
5742     // Keep the parameters on the stack.  If a function call contains out or inout parameters, we
5743     // need to know the access chains for the eventual write back to them.
5744     if (visit == InVisit)
5745     {
5746         return true;
5747     }
5748 
5749     // Expect to have accumulated as many parameters as the node requires.
5750     ASSERT(mNodeData.size() > node->getChildCount());
5751 
5752     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
5753     spirv::IdRef result;
5754 
5755     switch (node->getOp())
5756     {
5757         case EOpConstruct:
5758             // Construct a value out of the accumulated parameters.
5759             result = createConstructor(node, resultTypeId);
5760             break;
5761         case EOpCallFunctionInAST:
5762             // Create a call to the function.
5763             result = createFunctionCall(node, resultTypeId);
5764             break;
5765 
5766             // For barrier functions the scope is device, or with the Vulkan memory model, the queue
5767             // family.  We don't use the Vulkan memory model.
5768         case EOpBarrier:
5769             spirv::WriteControlBarrier(
5770                 mBuilder.getSpirvCurrentFunctionBlock(),
5771                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5772                 mBuilder.getUintConstant(spv::ScopeWorkgroup),
5773                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5774                                          spv::MemorySemanticsAcquireReleaseMask));
5775             break;
5776         case EOpBarrierTCS:
5777             // Note: The memory scope and semantics are different with the Vulkan memory model,
5778             // which is not supported.
5779             spirv::WriteControlBarrier(mBuilder.getSpirvCurrentFunctionBlock(),
5780                                        mBuilder.getUintConstant(spv::ScopeWorkgroup),
5781                                        mBuilder.getUintConstant(spv::ScopeInvocation),
5782                                        mBuilder.getUintConstant(spv::MemorySemanticsMaskNone));
5783             break;
5784         case EOpMemoryBarrier:
5785         case EOpGroupMemoryBarrier:
5786         {
5787             const spv::Scope scope =
5788                 node->getOp() == EOpMemoryBarrier ? spv::ScopeDevice : spv::ScopeWorkgroup;
5789             spirv::WriteMemoryBarrier(
5790                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(scope),
5791                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5792                                          spv::MemorySemanticsWorkgroupMemoryMask |
5793                                          spv::MemorySemanticsImageMemoryMask |
5794                                          spv::MemorySemanticsAcquireReleaseMask));
5795             break;
5796         }
5797         case EOpMemoryBarrierBuffer:
5798             spirv::WriteMemoryBarrier(
5799                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5800                 mBuilder.getUintConstant(spv::MemorySemanticsUniformMemoryMask |
5801                                          spv::MemorySemanticsAcquireReleaseMask));
5802             break;
5803         case EOpMemoryBarrierImage:
5804             spirv::WriteMemoryBarrier(
5805                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5806                 mBuilder.getUintConstant(spv::MemorySemanticsImageMemoryMask |
5807                                          spv::MemorySemanticsAcquireReleaseMask));
5808             break;
5809         case EOpMemoryBarrierShared:
5810             spirv::WriteMemoryBarrier(
5811                 mBuilder.getSpirvCurrentFunctionBlock(), mBuilder.getUintConstant(spv::ScopeDevice),
5812                 mBuilder.getUintConstant(spv::MemorySemanticsWorkgroupMemoryMask |
5813                                          spv::MemorySemanticsAcquireReleaseMask));
5814             break;
5815         case EOpMemoryBarrierAtomicCounter:
5816             // Atomic counters are emulated.
5817             UNREACHABLE();
5818             break;
5819 
5820         case EOpEmitVertex:
5821             if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
5822             {
5823                 // Add a non-semantic instruction before EmitVertex.
5824                 markVertexOutputOnEmitVertex();
5825             }
5826             spirv::WriteEmitVertex(mBuilder.getSpirvCurrentFunctionBlock());
5827             break;
5828         case EOpEndPrimitive:
5829             spirv::WriteEndPrimitive(mBuilder.getSpirvCurrentFunctionBlock());
5830             break;
5831 
5832         case EOpBeginInvocationInterlockARB:
5833             // Set up a "pixel_interlock_ordered" execution mode, as that is the default
5834             // interlocked execution mode in GLSL, and we don't currently expose an option to change
5835             // that.
5836             mBuilder.addExtension(SPIRVExtensions::FragmentShaderInterlockARB);
5837             mBuilder.addCapability(spv::CapabilityFragmentShaderPixelInterlockEXT);
5838             mBuilder.addExecutionMode(spv::ExecutionMode::ExecutionModePixelInterlockOrderedEXT);
5839             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5840             spirv::WriteBeginInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5841             break;
5842 
5843         case EOpEndInvocationInterlockARB:
5844             // Compile GL_ARB_fragment_shader_interlock to SPV_EXT_fragment_shader_interlock.
5845             spirv::WriteEndInvocationInterlockEXT(mBuilder.getSpirvCurrentFunctionBlock());
5846             break;
5847 
5848         default:
5849             result = visitOperator(node, resultTypeId);
5850             break;
5851     }
5852 
5853     // Pop the parameters.
5854     mNodeData.resize(mNodeData.size() - node->getChildCount());
5855 
5856     // Keep the result as rvalue.
5857     nodeDataInitRValue(&mNodeData.back(), result, resultTypeId);
5858 
5859     return false;
5860 }
5861 
visitDeclaration(Visit visit,TIntermDeclaration * node)5862 bool OutputSPIRVTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
5863 {
5864     const TIntermSequence &sequence = *node->getSequence();
5865 
5866     // Enforced by ValidateASTOptions::validateMultiDeclarations.
5867     ASSERT(sequence.size() == 1);
5868 
5869     // Declare specialization constants especially; they don't require processing the left and right
5870     // nodes, and they are like constant declarations with special instructions and decorations.
5871     const TQualifier qualifier = sequence.front()->getAsTyped()->getType().getQualifier();
5872     if (qualifier == EvqSpecConst)
5873     {
5874         declareSpecConst(node);
5875         return false;
5876     }
5877     // Similarly, constant declarations are turned into actual constants.
5878     if (qualifier == EvqConst)
5879     {
5880         declareConst(node);
5881         return false;
5882     }
5883 
5884     // Skip redeclaration of builtins.  They will correctly declare as built-in on first use.
5885     if (mInGlobalScope &&
5886         (qualifier == EvqClipDistance || qualifier == EvqCullDistance || qualifier == EvqFragDepth))
5887     {
5888         return false;
5889     }
5890 
5891     if (!mInGlobalScope && visit == PreVisit)
5892     {
5893         mNodeData.emplace_back();
5894     }
5895 
5896     mIsSymbolBeingDeclared = visit == PreVisit;
5897 
5898     if (visit != PostVisit)
5899     {
5900         return true;
5901     }
5902 
5903     TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
5904     spirv::IdRef initializerId;
5905     bool initializeWithDeclaration = false;
5906     bool needsQuantizeTo16         = false;
5907 
5908     // Handle declarations with initializer.
5909     if (symbol == nullptr)
5910     {
5911         TIntermBinary *assign = sequence.front()->getAsBinaryNode();
5912         ASSERT(assign != nullptr && assign->getOp() == EOpInitialize);
5913 
5914         symbol = assign->getLeft()->getAsSymbolNode();
5915         ASSERT(symbol != nullptr);
5916 
5917         // In SPIR-V, it's only possible to initialize a variable together with its declaration if
5918         // the initializer is a constant or a global variable.  We ignore the global variable case
5919         // to avoid tracking whether the variable has been modified since the beginning of the
5920         // function.  Since variable declarations are always placed at the beginning of the function
5921         // in SPIR-V, it would be wrong for example to initialize |var| below with the global
5922         // variable at declaration time:
5923         //
5924         //     vec4 global = A;
5925         //     void f()
5926         //     {
5927         //         global = B;
5928         //         {
5929         //             vec4 var = global;
5930         //         }
5931         //     }
5932         //
5933         // So the initializer is only used when declaring a variable when it's a constant
5934         // expression.  Note that if the variable being declared is itself global (and the
5935         // initializer is not constant), a previous AST transformation (DeferGlobalInitializers)
5936         // makes sure their initialization is deferred to the beginning of main.
5937         //
5938         // Additionally, if the variable is being defined inside a loop, the initializer is not used
5939         // as that would prevent it from being reinitialized in the next iteration of the loop.
5940 
5941         TIntermTyped *initializer = assign->getRight();
5942         initializeWithDeclaration =
5943             !mBuilder.isInLoop() &&
5944             (initializer->getAsConstantUnion() != nullptr || initializer->hasConstantValue());
5945 
5946         if (initializeWithDeclaration)
5947         {
5948             // If a constant, take the Id directly.
5949             initializerId = mNodeData.back().baseId;
5950         }
5951         else
5952         {
5953             // Otherwise generate code to load from right hand side expression.
5954             initializerId = accessChainLoad(&mNodeData.back(), symbol->getType(), nullptr);
5955 
5956             // Workaround for issuetracker.google.com/274859104
5957             // ARM SpirV compiler may utilize the RelaxedPrecision of mediump float,
5958             // and chooses to not cast mediump float to 16 bit. This causes deqp test
5959             // dEQP-GLES2.functional.shaders.algorithm.rgb_to_hsl_vertex failed.
5960             // The reason is that GLSL shader code expects below condition to be true:
5961             // mediump float a == mediump float b;
5962             // However, the condition is false after translating to SpirV
5963             // due to one of them is 32 bit, and the other is 16 bit.
5964             // To resolve the deqp test failure, we will add an OpQuantizeToF16
5965             // SpirV instruction to explicitly cast mediump float scalar or mediump float
5966             // vector to 16 bit, if the right-hand-side is a highp float.
5967             if (mCompileOptions.castMediumpFloatTo16Bit)
5968             {
5969                 const TType leftType            = assign->getLeft()->getType();
5970                 const TType rightType           = assign->getRight()->getType();
5971                 const TPrecision leftPrecision  = leftType.getPrecision();
5972                 const TPrecision rightPrecision = rightType.getPrecision();
5973                 const bool isLeftScalarFloat    = leftType.isScalarFloat();
5974                 const bool isLeftVectorFloat = leftType.isVector() && !leftType.isVectorArray() &&
5975                                                leftType.getBasicType() == EbtFloat;
5976 
5977                 if (leftPrecision == TPrecision::EbpMedium &&
5978                     rightPrecision == TPrecision::EbpHigh &&
5979                     (isLeftScalarFloat || isLeftVectorFloat))
5980                 {
5981                     needsQuantizeTo16 = true;
5982                 }
5983             }
5984         }
5985 
5986         // Clean up the initializer data.
5987         mNodeData.pop_back();
5988     }
5989 
5990     const TType &type         = symbol->getType();
5991     const TVariable *variable = &symbol->variable();
5992 
5993     // If this is just a struct declaration (and not a variable declaration), don't declare the
5994     // struct up-front and let it be lazily defined.  If the struct is only used inside an interface
5995     // block for example, this avoids it being doubly defined (once with the unspecified block
5996     // storage and once with interface block's).
5997     if (type.isStructSpecifier() && variable->symbolType() == SymbolType::Empty)
5998     {
5999         return false;
6000     }
6001 
6002     const spirv::IdRef typeId = mBuilder.getTypeData(type, {}).id;
6003 
6004     spv::StorageClass storageClass =
6005         GetStorageClass(mCompileOptions, type, mCompiler->getShaderType());
6006 
6007     SpirvDecorations decorations = mBuilder.getDecorations(type);
6008     if (mBuilder.isInvariantOutput(type))
6009     {
6010         // Apply the Invariant decoration to output variables if specified or if globally enabled.
6011         decorations.push_back(spv::DecorationInvariant);
6012     }
6013     // Apply the declared memory qualifiers.
6014     TMemoryQualifier memoryQualifier = type.getMemoryQualifier();
6015     if (memoryQualifier.coherent)
6016     {
6017         decorations.push_back(spv::DecorationCoherent);
6018     }
6019     if (memoryQualifier.volatileQualifier)
6020     {
6021         decorations.push_back(spv::DecorationVolatile);
6022     }
6023     if (memoryQualifier.restrictQualifier)
6024     {
6025         decorations.push_back(spv::DecorationRestrict);
6026     }
6027     if (memoryQualifier.readonly)
6028     {
6029         decorations.push_back(spv::DecorationNonWritable);
6030     }
6031     if (memoryQualifier.writeonly)
6032     {
6033         decorations.push_back(spv::DecorationNonReadable);
6034     }
6035 
6036     const spirv::IdRef variableId = mBuilder.declareVariable(
6037         typeId, storageClass, decorations, initializeWithDeclaration ? &initializerId : nullptr,
6038         mBuilder.getName(variable).data(), &variable->uniqueId());
6039 
6040     if (!initializeWithDeclaration && initializerId.valid())
6041     {
6042         // If not initializing at the same time as the declaration, issue a store
6043         if (needsQuantizeTo16)
6044         {
6045             // Insert OpQuantizeToF16 instruction to explicitly cast mediump float to 16 bit before
6046             // issuing an OpStore instruction.
6047             const spirv::IdRef quantizeToF16Result =
6048                 mBuilder.getNewId(mBuilder.getDecorations(symbol->getType()));
6049             spirv::WriteQuantizeToF16(mBuilder.getSpirvCurrentFunctionBlock(), typeId,
6050                                       quantizeToF16Result, initializerId);
6051             initializerId = quantizeToF16Result;
6052         }
6053         spirv::WriteStore(mBuilder.getSpirvCurrentFunctionBlock(), variableId, initializerId,
6054                           nullptr);
6055     }
6056 
6057     const bool isShaderInOut = IsShaderIn(type.getQualifier()) || IsShaderOut(type.getQualifier());
6058     const bool isInterfaceBlock = type.getBasicType() == EbtInterfaceBlock;
6059 
6060     // Add decorations, which apply to the element type of arrays, if array.
6061     spirv::IdRef nonArrayTypeId = typeId;
6062     if (type.isArray() && (isShaderInOut || isInterfaceBlock))
6063     {
6064         SpirvType elementType  = mBuilder.getSpirvType(type, {});
6065         elementType.arraySizes = {};
6066         nonArrayTypeId         = mBuilder.getSpirvTypeData(elementType, nullptr).id;
6067     }
6068 
6069     if (isShaderInOut)
6070     {
6071         if (IsShaderIoBlock(type.getQualifier()) && type.isInterfaceBlock())
6072         {
6073             // For gl_PerVertex in particular, write the necessary BuiltIn decorations
6074             if (type.getQualifier() == EvqPerVertexIn || type.getQualifier() == EvqPerVertexOut)
6075             {
6076                 mBuilder.writePerVertexBuiltIns(type, nonArrayTypeId);
6077             }
6078 
6079             // I/O blocks are decorated with Block
6080             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId,
6081                                  spv::DecorationBlock, {});
6082         }
6083         else if (type.getQualifier() == EvqPatchIn || type.getQualifier() == EvqPatchOut)
6084         {
6085             // Tessellation shaders can have their input or output qualified with |patch|.  For I/O
6086             // blocks, the members are decorated instead.
6087             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationPatch,
6088                                  {});
6089         }
6090     }
6091     else if (isInterfaceBlock)
6092     {
6093         // For uniform and buffer variables, with SPIR-V 1.3 add Block and BufferBlock decorations
6094         // respectively.  With SPIR-V 1.4, always add Block.
6095         const spv::Decoration decoration =
6096             mCompileOptions.emitSPIRV14 || type.getQualifier() == EvqUniform
6097                 ? spv::DecorationBlock
6098                 : spv::DecorationBufferBlock;
6099         spirv::WriteDecorate(mBuilder.getSpirvDecorations(), nonArrayTypeId, decoration, {});
6100 
6101         if (type.getQualifier() == EvqBuffer && !memoryQualifier.restrictQualifier &&
6102             mCompileOptions.aliasedUnlessRestrict)
6103         {
6104             // If GLSL does not specify the SSBO has restrict memory qualifier, assume the
6105             // memory qualifier is aliased
6106             // issuetracker.google.com/266235549
6107             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6108                                  {});
6109         }
6110     }
6111     else if (IsImage(type.getBasicType()) && type.getQualifier() == EvqUniform)
6112     {
6113         // If GLSL does not specify the image has restrict memory qualifier, assume the memory
6114         // qualifier is aliased
6115         // issuetracker.google.com/266235549
6116         if (!memoryQualifier.restrictQualifier && mCompileOptions.aliasedUnlessRestrict)
6117         {
6118             spirv::WriteDecorate(mBuilder.getSpirvDecorations(), variableId, spv::DecorationAliased,
6119                                  {});
6120         }
6121     }
6122 
6123     // Write DescriptorSet, Binding, Location etc decorations if necessary.
6124     mBuilder.writeInterfaceVariableDecorations(type, variableId);
6125 
6126     // Remember the id of the variable for future look up.  For interface blocks, also remember the
6127     // id of the interface block.
6128     ASSERT(mSymbolIdMap.count(variable) == 0);
6129     mSymbolIdMap[variable] = variableId;
6130 
6131     if (type.isInterfaceBlock())
6132     {
6133         ASSERT(mSymbolIdMap.count(type.getInterfaceBlock()) == 0);
6134         mSymbolIdMap[type.getInterfaceBlock()] = variableId;
6135     }
6136 
6137     return false;
6138 }
6139 
GetLoopBlocks(const SpirvConditional * conditional,TLoopType loopType,bool hasCondition,spirv::IdRef * headerBlock,spirv::IdRef * condBlock,spirv::IdRef * bodyBlock,spirv::IdRef * continueBlock,spirv::IdRef * mergeBlock)6140 void GetLoopBlocks(const SpirvConditional *conditional,
6141                    TLoopType loopType,
6142                    bool hasCondition,
6143                    spirv::IdRef *headerBlock,
6144                    spirv::IdRef *condBlock,
6145                    spirv::IdRef *bodyBlock,
6146                    spirv::IdRef *continueBlock,
6147                    spirv::IdRef *mergeBlock)
6148 {
6149     // The order of the blocks is for |for| and |while|:
6150     //
6151     //     %header %cond [optional] %body %continue %merge
6152     //
6153     // and for |do-while|:
6154     //
6155     //     %header %body %cond %merge
6156     //
6157     // Note that the |break| target is always the last block and the |continue| target is the one
6158     // before last.
6159     //
6160     // If %continue is not present, all jumps are made to %cond (which is necessarily present).
6161     // If %cond is not present, all jumps are made to %body instead.
6162 
6163     size_t nextBlock = 0;
6164     *headerBlock     = conditional->blockIds[nextBlock++];
6165     // %cond, if any is after header except for |do-while|.
6166     if (loopType != ELoopDoWhile && hasCondition)
6167     {
6168         *condBlock = conditional->blockIds[nextBlock++];
6169     }
6170     *bodyBlock = conditional->blockIds[nextBlock++];
6171     // After the block is either %cond or %continue based on |do-while| or not.
6172     if (loopType != ELoopDoWhile)
6173     {
6174         *continueBlock = conditional->blockIds[nextBlock++];
6175     }
6176     else
6177     {
6178         *condBlock = conditional->blockIds[nextBlock++];
6179     }
6180     *mergeBlock = conditional->blockIds[nextBlock++];
6181 
6182     ASSERT(nextBlock == conditional->blockIds.size());
6183 
6184     if (!continueBlock->valid())
6185     {
6186         ASSERT(condBlock->valid());
6187         *continueBlock = *condBlock;
6188     }
6189     if (!condBlock->valid())
6190     {
6191         *condBlock = *bodyBlock;
6192     }
6193 }
6194 
visitLoop(Visit visit,TIntermLoop * node)6195 bool OutputSPIRVTraverser::visitLoop(Visit visit, TIntermLoop *node)
6196 {
6197     // There are three kinds of loops, and they translate as such:
6198     //
6199     // for (init; cond; expr) body;
6200     //
6201     //               // pre-loop block
6202     //               init
6203     //               OpBranch %header
6204     //
6205     //     %header = OpLabel
6206     //               OpLoopMerge %merge %continue None
6207     //               OpBranch %cond
6208     //
6209     //               // Note: if cond doesn't exist, this section is not generated.  The above
6210     //               // OpBranch would jump directly to %body.
6211     //       %cond = OpLabel
6212     //          %v = cond
6213     //               OpBranchConditional %v %body %merge None
6214     //
6215     //       %body = OpLabel
6216     //               body
6217     //               OpBranch %continue
6218     //
6219     //   %continue = OpLabel
6220     //               expr
6221     //               OpBranch %header
6222     //
6223     //               // post-loop block
6224     //       %merge = OpLabel
6225     //
6226     //
6227     // while (cond) body;
6228     //
6229     //               // pre-for block
6230     //               OpBranch %header
6231     //
6232     //     %header = OpLabel
6233     //               OpLoopMerge %merge %continue None
6234     //               OpBranch %cond
6235     //
6236     //       %cond = OpLabel
6237     //          %v = cond
6238     //               OpBranchConditional %v %body %merge None
6239     //
6240     //       %body = OpLabel
6241     //               body
6242     //               OpBranch %continue
6243     //
6244     //   %continue = OpLabel
6245     //               OpBranch %header
6246     //
6247     //               // post-loop block
6248     //       %merge = OpLabel
6249     //
6250     //
6251     // do body; while (cond);
6252     //
6253     //               // pre-for block
6254     //               OpBranch %header
6255     //
6256     //     %header = OpLabel
6257     //               OpLoopMerge %merge %cond None
6258     //               OpBranch %body
6259     //
6260     //       %body = OpLabel
6261     //               body
6262     //               OpBranch %cond
6263     //
6264     //       %cond = OpLabel
6265     //          %v = cond
6266     //               OpBranchConditional %v %header %merge None
6267     //
6268     //               // post-loop block
6269     //       %merge = OpLabel
6270     //
6271 
6272     // The order of the blocks is not necessarily the same as traversed, so it's much simpler if
6273     // this function enforces traversal in the right order.
6274     ASSERT(visit == PreVisit);
6275     mNodeData.emplace_back();
6276 
6277     const TLoopType loopType = node->getType();
6278 
6279     // The init statement of a for loop is placed in the previous block, so continue generating code
6280     // as-is until that statement is done.
6281     if (node->getInit())
6282     {
6283         ASSERT(loopType == ELoopFor);
6284         node->getInit()->traverse(this);
6285         mNodeData.pop_back();
6286     }
6287 
6288     const bool hasCondition = node->getCondition() != nullptr;
6289 
6290     // Once the init node is visited, if any, we need to set up the loop.
6291     //
6292     // For |for| and |while|, we need %header, %body, %continue and %merge.  For |do-while|, we
6293     // need %header, %body and %merge.  If condition is present, an additional %cond block is
6294     // needed in each case.
6295     const size_t blockCount = (loopType == ELoopDoWhile ? 3 : 4) + (hasCondition ? 1 : 0);
6296     mBuilder.startConditional(blockCount, true, true);
6297 
6298     // Generate the %header block.
6299     const SpirvConditional *conditional = mBuilder.getCurrentConditional();
6300 
6301     spirv::IdRef headerBlock, condBlock, bodyBlock, continueBlock, mergeBlock;
6302     GetLoopBlocks(conditional, loopType, hasCondition, &headerBlock, &condBlock, &bodyBlock,
6303                   &continueBlock, &mergeBlock);
6304 
6305     mBuilder.writeLoopHeader(loopType == ELoopDoWhile ? bodyBlock : condBlock, continueBlock,
6306                              mergeBlock);
6307 
6308     // %cond, if any is after header except for |do-while|.
6309     if (loopType != ELoopDoWhile && hasCondition)
6310     {
6311         node->getCondition()->traverse(this);
6312 
6313         // Generate the branch at the end of the %cond block.
6314         const spirv::IdRef conditionValue =
6315             accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6316         mBuilder.writeLoopConditionEnd(conditionValue, bodyBlock, mergeBlock);
6317 
6318         mNodeData.pop_back();
6319     }
6320 
6321     // Next comes %body.
6322     {
6323         node->getBody()->traverse(this);
6324 
6325         // Generate the branch at the end of the %body block.
6326         mBuilder.writeLoopBodyEnd(continueBlock);
6327     }
6328 
6329     switch (loopType)
6330     {
6331         case ELoopFor:
6332             // For |for| loops, the expression is placed after the body and acts as the continue
6333             // block.
6334             if (node->getExpression())
6335             {
6336                 node->getExpression()->traverse(this);
6337                 mNodeData.pop_back();
6338             }
6339 
6340             // Generate the branch at the end of the %continue block.
6341             mBuilder.writeLoopContinueEnd(headerBlock);
6342             break;
6343 
6344         case ELoopWhile:
6345             // |for| loops have the expression in the continue block and |do-while| loops have their
6346             // condition block act as the loop's continue block.  |while| loops need a branch-only
6347             // continue loop, which is generated here.
6348             mBuilder.writeLoopContinueEnd(headerBlock);
6349             break;
6350 
6351         case ELoopDoWhile:
6352             // For |do-while|, %cond comes last.
6353             ASSERT(hasCondition);
6354             node->getCondition()->traverse(this);
6355 
6356             // Generate the branch at the end of the %cond block.
6357             const spirv::IdRef conditionValue =
6358                 accessChainLoad(&mNodeData.back(), node->getCondition()->getType(), nullptr);
6359             mBuilder.writeLoopConditionEnd(conditionValue, headerBlock, mergeBlock);
6360 
6361             mNodeData.pop_back();
6362             break;
6363     }
6364 
6365     // Pop from the conditional stack when done.
6366     mBuilder.endConditional();
6367 
6368     // Don't traverse the children, that's done already.
6369     return false;
6370 }
6371 
visitBranch(Visit visit,TIntermBranch * node)6372 bool OutputSPIRVTraverser::visitBranch(Visit visit, TIntermBranch *node)
6373 {
6374     if (visit == PreVisit)
6375     {
6376         mNodeData.emplace_back();
6377         return true;
6378     }
6379 
6380     // There is only ever one child at most.
6381     ASSERT(visit != InVisit);
6382 
6383     switch (node->getFlowOp())
6384     {
6385         case EOpKill:
6386             spirv::WriteKill(mBuilder.getSpirvCurrentFunctionBlock());
6387             mBuilder.terminateCurrentFunctionBlock();
6388             break;
6389         case EOpBreak:
6390             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6391                                mBuilder.getBreakTargetId());
6392             mBuilder.terminateCurrentFunctionBlock();
6393             break;
6394         case EOpContinue:
6395             spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(),
6396                                mBuilder.getContinueTargetId());
6397             mBuilder.terminateCurrentFunctionBlock();
6398             break;
6399         case EOpReturn:
6400             // Evaluate the expression if any, and return.
6401             if (node->getExpression() != nullptr)
6402             {
6403                 ASSERT(mNodeData.size() >= 1);
6404 
6405                 const spirv::IdRef expressionValue =
6406                     accessChainLoad(&mNodeData.back(), node->getExpression()->getType(), nullptr);
6407                 mNodeData.pop_back();
6408 
6409                 spirv::WriteReturnValue(mBuilder.getSpirvCurrentFunctionBlock(), expressionValue);
6410                 mBuilder.terminateCurrentFunctionBlock();
6411             }
6412             else
6413             {
6414                 if (mCurrentFunctionId == vk::spirv::kIdEntryPoint)
6415                 {
6416                     // For main(), add a non-semantic instruction at the end of the shader.
6417                     markVertexOutputOnShaderEnd();
6418                 }
6419                 spirv::WriteReturn(mBuilder.getSpirvCurrentFunctionBlock());
6420                 mBuilder.terminateCurrentFunctionBlock();
6421             }
6422             break;
6423         default:
6424             UNREACHABLE();
6425     }
6426 
6427     return true;
6428 }
6429 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)6430 void OutputSPIRVTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
6431 {
6432     // No preprocessor directives expected at this point.
6433     UNREACHABLE();
6434 }
6435 
markVertexOutputOnShaderEnd()6436 void OutputSPIRVTraverser::markVertexOutputOnShaderEnd()
6437 {
6438     // Output happens in vertex and fragment stages at return from main.
6439     // In geometry shaders, it's done at EmitVertex.
6440     switch (mCompiler->getShaderType())
6441     {
6442         case GL_FRAGMENT_SHADER:
6443         case GL_VERTEX_SHADER:
6444         case GL_TESS_CONTROL_SHADER_EXT:
6445         case GL_TESS_EVALUATION_SHADER_EXT:
6446             mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6447             break;
6448         default:
6449             break;
6450     }
6451 }
6452 
markVertexOutputOnEmitVertex()6453 void OutputSPIRVTraverser::markVertexOutputOnEmitVertex()
6454 {
6455     // Vertex output happens in the geometry stage at EmitVertex.
6456     if (mCompiler->getShaderType() == GL_GEOMETRY_SHADER)
6457     {
6458         mBuilder.writeNonSemanticInstruction(vk::spirv::kNonSemanticOutput);
6459     }
6460 }
6461 
getSpirv()6462 spirv::Blob OutputSPIRVTraverser::getSpirv()
6463 {
6464     spirv::Blob result = mBuilder.getSpirv();
6465 
6466     // Validate that correct SPIR-V was generated
6467     ASSERT(spirv::Validate(result));
6468 
6469 #if ANGLE_DEBUG_SPIRV_GENERATION
6470     // Disassemble and log the generated SPIR-V for debugging.
6471     spvtools::SpirvTools spirvTools(mCompileOptions.emitSPIRV14 ? SPV_ENV_VULKAN_1_1_SPIRV_1_4
6472                                                                 : SPV_ENV_VULKAN_1_1);
6473     std::string readableSpirv;
6474     spirvTools.Disassemble(result, &readableSpirv,
6475                            SPV_BINARY_TO_TEXT_OPTION_COMMENT | SPV_BINARY_TO_TEXT_OPTION_INDENT |
6476                                SPV_BINARY_TO_TEXT_OPTION_NESTED_INDENT);
6477     fprintf(stderr, "%s\n", readableSpirv.c_str());
6478 #endif  // ANGLE_DEBUG_SPIRV_GENERATION
6479 
6480     return result;
6481 }
6482 }  // anonymous namespace
6483 
OutputSPIRV(TCompiler * compiler,TIntermBlock * root,const ShCompileOptions & compileOptions,const angle::HashMap<int,uint32_t> & uniqueToSpirvIdMap,uint32_t firstUnusedSpirvId)6484 bool OutputSPIRV(TCompiler *compiler,
6485                  TIntermBlock *root,
6486                  const ShCompileOptions &compileOptions,
6487                  const angle::HashMap<int, uint32_t> &uniqueToSpirvIdMap,
6488                  uint32_t firstUnusedSpirvId)
6489 {
6490     // Find the list of nodes that require NoContraction (as a result of |precise|).
6491     if (compiler->hasAnyPreciseType())
6492     {
6493         FindPreciseNodes(compiler, root);
6494     }
6495 
6496     // Traverse the tree and generate SPIR-V instructions
6497     OutputSPIRVTraverser traverser(compiler, compileOptions, uniqueToSpirvIdMap,
6498                                    firstUnusedSpirvId);
6499     root->traverse(&traverser);
6500 
6501     // Generate the final SPIR-V and store in the sink
6502     spirv::Blob spirvBlob = traverser.getSpirv();
6503     compiler->getInfoSink().obj.setBinary(std::move(spirvBlob));
6504 
6505     return true;
6506 }
6507 }  // namespace sh
6508