xref: /aosp_15_r20/external/skia/src/sksl/codegen/SkSLSPIRVCodeGenerator.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/analysis/SkSLSpecialization.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
46 #include "src/sksl/ir/SkSLConstructorSplat.h"
47 #include "src/sksl/ir/SkSLDoStatement.h"
48 #include "src/sksl/ir/SkSLExpression.h"
49 #include "src/sksl/ir/SkSLExpressionStatement.h"
50 #include "src/sksl/ir/SkSLExtension.h"
51 #include "src/sksl/ir/SkSLFieldAccess.h"
52 #include "src/sksl/ir/SkSLFieldSymbol.h"
53 #include "src/sksl/ir/SkSLForStatement.h"
54 #include "src/sksl/ir/SkSLFunctionCall.h"
55 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
56 #include "src/sksl/ir/SkSLFunctionDefinition.h"
57 #include "src/sksl/ir/SkSLIRNode.h"
58 #include "src/sksl/ir/SkSLIfStatement.h"
59 #include "src/sksl/ir/SkSLIndexExpression.h"
60 #include "src/sksl/ir/SkSLInterfaceBlock.h"
61 #include "src/sksl/ir/SkSLLayout.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLModifierFlags.h"
64 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
65 #include "src/sksl/ir/SkSLPoison.h"
66 #include "src/sksl/ir/SkSLPostfixExpression.h"
67 #include "src/sksl/ir/SkSLPrefixExpression.h"
68 #include "src/sksl/ir/SkSLProgram.h"
69 #include "src/sksl/ir/SkSLProgramElement.h"
70 #include "src/sksl/ir/SkSLReturnStatement.h"
71 #include "src/sksl/ir/SkSLSetting.h"
72 #include "src/sksl/ir/SkSLStatement.h"
73 #include "src/sksl/ir/SkSLSwitchCase.h"
74 #include "src/sksl/ir/SkSLSwitchStatement.h"
75 #include "src/sksl/ir/SkSLSwizzle.h"
76 #include "src/sksl/ir/SkSLSymbol.h"
77 #include "src/sksl/ir/SkSLSymbolTable.h"
78 #include "src/sksl/ir/SkSLTernaryExpression.h"
79 #include "src/sksl/ir/SkSLType.h"
80 #include "src/sksl/ir/SkSLVarDeclarations.h"
81 #include "src/sksl/ir/SkSLVariable.h"
82 #include "src/sksl/ir/SkSLVariableReference.h"
83 #include "src/sksl/spirv.h"
84 #include "src/sksl/transform/SkSLTransform.h"
85 #include "src/utils/SkBitSet.h"
86 
87 #include <algorithm>
88 #include <cstdint>
89 #include <cstring>
90 #include <ctype.h>
91 #include <functional>
92 #include <memory>
93 #include <set>
94 #include <string>
95 #include <string_view>
96 #include <tuple>
97 #include <utility>
98 #include <vector>
99 
100 using namespace skia_private;
101 
102 #define kLast_Capability SpvCapabilityMultiViewport  // Alias for SpvCapabilityMultiViewport. NOTE(review): presumably the upper bound of the capability bits tracked in fCapabilities -- confirm at the use site.
103 
104 constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;  // NOTE(review): negative pseudo-builtin id; presumably chosen so it cannot collide with a real builtin value -- confirm where it is consumed.
105 constexpr int DEVICE_CLOCKWISE_BUILTIN  = -1001;  // NOTE(review): second negative pseudo-builtin id, same assumption as above.
106 static constexpr SkSL::Layout kDefaultTypeLayout;  // Default-constructed Layout (no initializer is supplied).
107 
108 namespace SkSL {
109 
110 enum class ProgramKind : int8_t;
111 
112 enum class StorageClass {  // SkSL-side storage classes; get_storage_class_spv_id() converts each one to a SpvStorageClass.
113     kUniformConstant,
114     kInput,
115     kUniform,
116     kStorageBuffer,  // Converted to SpvStorageClassUniform, the same value that kUniform is converted to.
117     kOutput,
118     kWorkgroup,
119     kCrossWorkgroup,
120     kPrivate,
121     kFunction,
122     kGeneric,
123     kPushConstant,
124     kAtomicCounter,
125     kImage,
126 };
127 
get_storage_class_spv_id(StorageClass storageClass)128 static SpvStorageClass get_storage_class_spv_id(StorageClass storageClass) {  // Converts an SkSL StorageClass into the SpvStorageClass enumerant written to the SPIR-V output.
129     switch (storageClass) {
130         case StorageClass::kUniformConstant: return SpvStorageClassUniformConstant;
131         case StorageClass::kInput: return SpvStorageClassInput;
132         case StorageClass::kUniform: return SpvStorageClassUniform;
133         // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
134         // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
135         // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
136         // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
137         // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
138         // would benefit from SPV_KHR_variable_pointers capabilities).
139         case StorageClass::kStorageBuffer: return SpvStorageClassUniform;
140         case StorageClass::kOutput: return SpvStorageClassOutput;
141         case StorageClass::kWorkgroup: return SpvStorageClassWorkgroup;
142         case StorageClass::kCrossWorkgroup: return SpvStorageClassCrossWorkgroup;
143         case StorageClass::kPrivate: return SpvStorageClassPrivate;
144         case StorageClass::kFunction: return SpvStorageClassFunction;
145         case StorageClass::kGeneric: return SpvStorageClassGeneric;
146         case StorageClass::kPushConstant: return SpvStorageClassPushConstant;
147         case StorageClass::kAtomicCounter: return SpvStorageClassAtomicCounter;
148         case StorageClass::kImage: return SpvStorageClassImage;
149     }
150 
151     SkUNREACHABLE;  // Every StorageClass enumerator returns from the switch above, so only an invalid enum value can get here.
152 }
153 
154 class SPIRVCodeGenerator : public CodeGenerator {
155 public:
156     // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
157     static constexpr SpvId NA = (SpvId)-1;
158 
159     class LValue {  // Abstract interface for an assignable location; subclasses must implement storageClass(), load() and store().
160     public:
~LValue()161         virtual ~LValue() {}
162 
163         // Returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
164         // by a pointer (e.g. vector swizzles), returns NA. The base implementation returns NA.
getPointer()165         virtual SpvId getPointer() { return NA; }
166 
167         // Returns true if a valid pointer returned by getPointer represents a memory object
168         // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
169         // getPointer() returns NA. The base implementation returns true.
isMemoryObjectPointer() const170         virtual bool isMemoryObjectPointer() const { return true; }
171 
172         // Applies a swizzle to the components of the LValue, if possible. This is used to create
173         // LValues that are swizzles-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)174         virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
175             return false;
176         }
177 
178         // Returns the storage class of the lvalue.
179         virtual StorageClass storageClass() const = 0;
180 
181         virtual SpvId load(OutputStream& out) = 0;  // Pure virtual. NOTE(review): presumably emits a load into `out` and returns the id of the loaded value -- confirm in the subclasses.
182 
183         virtual void store(SpvId value, OutputStream& out) = 0;  // Pure virtual. NOTE(review): presumably emits a store of `value` into `out` -- confirm in the subclasses.
184     };
185 
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)186     SPIRVCodeGenerator(const Context* context,  // Passes every argument through to the CodeGenerator base class; the constructor body is empty.
187                        const ShaderCaps* caps,
188                        const Program* program,
189                        OutputStream* out)
190             : CodeGenerator(context, caps, program, out) {}
191 
192     bool generateCode() override;
193 
194 private:
195     enum IntrinsicOpcodeKind {  // Stored in Intrinsic::opKind. NOTE(review): presumably selects how an intrinsic call is lowered -- confirm in getIntrinsic().
196         kGLSL_STD_450_IntrinsicOpcodeKind,
197         kSPIRV_IntrinsicOpcodeKind,
198         kSpecial_IntrinsicOpcodeKind,
199         kInvalid_IntrinsicOpcodeKind,
200     };
201 
202     enum SpecialIntrinsic {  // Identifiers passed as the `kind` argument of writeSpecialIntrinsic() and writeAtomicIntrinsic().
203         kAtan_SpecialIntrinsic,
204         kClamp_SpecialIntrinsic,
205         kMatrixCompMult_SpecialIntrinsic,
206         kMax_SpecialIntrinsic,
207         kMin_SpecialIntrinsic,
208         kMix_SpecialIntrinsic,
209         kMod_SpecialIntrinsic,
210         kDFdy_SpecialIntrinsic,
211         kSaturate_SpecialIntrinsic,
212         kSampledImage_SpecialIntrinsic,
213         kSmoothStep_SpecialIntrinsic,
214         kStep_SpecialIntrinsic,
215         kSubpassLoad_SpecialIntrinsic,
216         kTexture_SpecialIntrinsic,
217         kTextureGrad_SpecialIntrinsic,
218         kTextureLod_SpecialIntrinsic,
219         kTextureRead_SpecialIntrinsic,
220         kTextureWrite_SpecialIntrinsic,
221         kTextureWidth_SpecialIntrinsic,
222         kTextureHeight_SpecialIntrinsic,
223         kAtomicAdd_SpecialIntrinsic,
224         kAtomicLoad_SpecialIntrinsic,
225         kAtomicStore_SpecialIntrinsic,
226         kStorageBarrier_SpecialIntrinsic,
227         kWorkgroupBarrier_SpecialIntrinsic,
228     };
229 
230     enum class Precision {  // Precision request accepted by nextId(Precision) and writeOpLoad().
231         kDefault,
232         kRelaxed,  // NOTE(review): presumably asks for the RelaxedPrecision decoration described on nextId(const Type*) -- confirm.
233     };
234 
235     struct TempVar {  // Element type of the `tempVars` vectors taken by writeFunctionCallArgument() and copyBackTempVars().
236         SpvId spvId;
237         const Type* type;  // Non-owning pointer.
238         std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;  // Owned. NOTE(review): presumably the destination that copyBackTempVars() stores into -- confirm.
239     };
240 
241     /**
242      * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
243      * appropriate, or null to never add one.
244      */
245     SpvId nextId(const Type* type);
246 
247     SpvId nextId(Precision precision);
248 
249     SpvId getType(const Type& type);
250 
251     SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
252 
253     SpvId getFunctionType(const FunctionDeclaration& function);
254 
255     SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
256 
257     SpvId getPointerType(const Type& type, StorageClass storageClass);
258 
259     SpvId getPointerType(const Type& type,
260                          const Layout& typeLayout,
261                          const MemoryLayout& memoryLayout,
262                          StorageClass storageClass);
263 
264     StorageClass getStorageClass(const Expression& expr);
265 
266     TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
267 
268     void writeLayout(const Layout& layout, SpvId target, Position pos);
269 
270     void writeFieldLayout(const Layout& layout, SpvId target, int member);
271 
272     SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
273 
274     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
275 
276     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
277 
278     void writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
279 
280     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
281 
282     void writeFunction(const FunctionDefinition& f, OutputStream& out);
283 
284     // Writes the function using the given specializationIndex. If the index is -1, the function
285     // is assumed to have no specializations.
286     void writeFunctionInstantiation(const FunctionDefinition& f,
287                                     Analysis::SpecializationIndex specializationIndex,
288                                     const Analysis::SpecializedParameters* specializedParams,
289                                     OutputStream& out);
290 
291     bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
292 
293     SpvId writeGlobalVar(ProgramKind kind, StorageClass, const Variable& v);
294 
295     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
296 
297     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
298 
299     int findUniformFieldIndex(const Variable& var) const;
300 
301     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
302 
303     SpvId writeExpression(const Expression& expr, OutputStream& out);
304 
305     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
306 
307     void writeFunctionCallArgument(TArray<SpvId>& argumentList,
308                                    const FunctionCall& call,
309                                    int argIndex,
310                                    std::vector<TempVar>* tempVars,
311                                    const SkBitSet* specializedParams,
312                                    OutputStream& out);
313 
314     void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
315 
316     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
317 
318 
319     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
320                                       SpvId signedInst, SpvId unsignedInst,
321                                       const TArray<SpvId>& args, OutputStream& out);
322 
323     /**
324      * Promotes an expression to a vector. If the expression is already a vector with vectorSize
325      * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
326      * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
327      * expression is already a vector and it does not have vectorSize columns.
328      */
329     SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
330 
331     /**
332      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
333      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
334      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
335      * vec2, vec3).
336      */
337     TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
338 
339     /**
340      * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
341      * returns the SpvId of the new value.
342      */
343     SpvId splat(const Type& type, SpvId id, OutputStream& out);
344 
345     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
346     SpvId writeAtomicIntrinsic(const FunctionCall& c,
347                                SpecialIntrinsic kind,
348                                SpvId resultId,
349                                OutputStream& out);
350 
351     SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
352                             OutputStream& out);
353 
354     SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
355                                 OutputStream& out);
356 
357     SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
358                                   OutputStream& out);
359 
360     SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
361                               OutputStream& out);
362 
363     SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
364                            OutputStream& out);
365 
366     /**
367      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
368      * source matrix are filled with zero; entries which do not exist in the destination matrix are
369      * ignored.
370      */
371     SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
372 
373     void addColumnEntry(const Type& columnType,
374                         TArray<SpvId>* currentColumn,
375                         TArray<SpvId>* columnIds,
376                         int rows,
377                         SpvId entry,
378                         OutputStream& out);
379 
380     SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
381 
382     SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
383 
384     SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
385 
386     SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
387 
388     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
389 
390     SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
391 
392     SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
393 
394     SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
395 
396     SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
397 
398     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
399 
400     SpvId writeSwizzle(const Expression& baseExpr,
401                        const ComponentArray& components,
402                        OutputStream& out);
403 
404     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
405 
406     /**
407      * Folds the potentially-vector result of a logical operation down to a single bool. If
408      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
409      * same dimensions, and applies all() to it to fold it down to a single bool value. Otherwise,
410      * returns the original id value.
411      */
412     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
413 
414     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
415                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
416                                 SpvOp_ mergeOperator, OutputStream& out);
417 
418     SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
419                                 OutputStream& out);
420 
421     SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
422                                OutputStream& out);
423 
424     // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
425     // comparisons into an overall comparison result.
426     // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
427     // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
428     SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
429 
430     // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
431     // multiplication into a sum of vector-scalar products.
432     SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
433                                               SpvId lhs,
434                                               const Type& rightType,
435                                               SpvId rhs,
436                                               const Type& resultType,
437                                               OutputStream& out);
438 
439     SpvId writeComponentwiseMatrixUnary(const Type& operandType,
440                                         SpvId operand,
441                                         SpvOp_ op,
442                                         OutputStream& out);
443 
444     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
445                                          SpvOp_ op, OutputStream& out);
446 
447     SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
448                                                     SpvId lhs, SpvId rhs,
449                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
450                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
451                                                     OutputStream& out);
452 
453     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
454                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
455                                SpvOp_ ifBool, OutputStream& out);
456 
457     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
458                                SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
459                                SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
460 
461     SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
462 
463     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
464                                 const Type& rightType, SpvId rhs, const Type& resultType,
465                                 OutputStream& out);
466 
467     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
468 
469     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
470 
471     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
472 
473     SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
474 
475     SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
476 
477     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
478 
479     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
480 
481     SpvId writeLiteral(const Literal& f);
482 
483     SpvId writeLiteral(double value, const Type& type);
484 
485     void writeStatement(const Statement& s, OutputStream& out);
486 
487     void writeBlock(const Block& b, OutputStream& out);
488 
489     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
490 
491     void writeForStatement(const ForStatement& f, OutputStream& out);
492 
493     void writeDoStatement(const DoStatement& d, OutputStream& out);
494 
495     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
496 
497     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
498 
499     void writeCapabilities(OutputStream& out);
500 
501     void writeInstructions(const Program& program, OutputStream& out);
502 
503     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
504 
505     void writeWord(int32_t word, OutputStream& out);
506 
507     void writeString(std::string_view s, OutputStream& out);
508 
509     void writeInstruction(SpvOp_ opCode, OutputStream& out);
510 
511     void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
512 
513     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
514 
515     void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
516                           OutputStream& out);
517 
518     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
519                           OutputStream& out);
520 
521     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
522 
523     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
524                           OutputStream& out);
525 
526     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
527                           OutputStream& out);
528 
529     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
530                           int32_t word5, OutputStream& out);
531 
532     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
533                           int32_t word5, int32_t word6, OutputStream& out);
534 
535     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
536                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
537 
538     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
539                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
540                           OutputStream& out);
541 
542     // This form of writeInstruction can deduplicate redundant ops.
543     struct Word;
544     // 8 Words is enough for nearly all instructions (except variable-length instructions like
545     // OpAccessChain or OpConstantComposite).
546     using Words = STArray<8, Word, true>;
547     SpvId writeInstruction(SpvOp_ opCode, const TArray<Word, true>& words, OutputStream& out);
548 
549     struct Instruction {  // Key type of fOpCache and mapped type of fSpvIdCache; produced by BuildInstructionKey().
550         SpvId fOp;
551         int32_t fResultKind;
552         STArray<8, int32_t>  fWords;
553 
554         bool operator==(const Instruction& that) const;  // Declared here, defined elsewhere; fOpCache uses Instruction as its key.
555         struct Hash;  // Hasher named in the fOpCache declaration; defined elsewhere.
556     };
557 
558     static Instruction BuildInstructionKey(SpvOp_ opCode, const TArray<Word, true>& words);
559 
560     // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
561     SpvId writeOpConstantTrue(const Type& type);
562     SpvId writeOpConstantFalse(const Type& type);
563     SpvId writeOpConstant(const Type& type, int32_t valueBits);
564     SpvId writeOpConstantComposite(const Type& type, const TArray<SpvId>& values);
565     SpvId writeOpCompositeConstruct(const Type& type, const TArray<SpvId>&, OutputStream& out);
566     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
567     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
568                                   OutputStream& out);
569     SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
570     void writeOpStore(StorageClass storageClass, SpvId pointer, SpvId value, OutputStream& out);
571 
572     // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
573     bool toConstants(SpvId value, TArray<SpvId>* constants);
574     bool toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants);
575 
576     // Extracts the requested component SpvId from a composite instruction, if it can be done.
577     Instruction* resultTypeForInstruction(const Instruction& instr);
578     int numComponentsForVecInstruction(const Instruction& instr);
579     SpvId toComponent(SpvId id, int component);
580 
581     struct ConditionalOpCounts {  // Returned by getConditionalOpCounts(); accepted by pruneConditionalOps() and the branching writeLabel() overload.
582         int numReachableOps;  // NOTE(review): presumably a snapshot of the fReachableOps length -- confirm.
583         int numStoreOps;  // NOTE(review): presumably a snapshot of the fStoreOps length -- confirm.
584     };
585     ConditionalOpCounts getConditionalOpCounts();
586     void pruneConditionalOps(ConditionalOpCounts ops);
587 
588     enum StraightLineLabelType {  // Label kinds taken by the writeLabel() overload that has no ConditionalOpCounts parameter.
589         // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
590         // happens at the start of a function, or when we find unreachable code.
591         kBranchlessBlock,
592 
593         // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
594         // associated branch. Example usage:
595         // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
596         //   use an OpBranch to explicitly jump to the next block, even when they are adjacent in
597         //   the code.
598         // - The block immediately following an OpBranchConditional or OpSwitch.
599         kBranchIsOnPreviousLine,
600     };
601 
602     enum BranchingLabelType {  // Label kinds taken by the writeLabel() overload that also receives a ConditionalOpCounts.
603         // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
604         // ops that are above the label in the code--i.e., the branch skips forward in the code.
605         kBranchIsAbove,
606 
607         // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
608         // ops below the label in the code--i.e., the branch jumps backward in the code.
609         kBranchIsBelow,
610 
611         // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
612         kBranchesOnBothSides,
613     };
614     void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
615     void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
616                     OutputStream& out);
617 
618     MemoryLayout memoryLayoutForStorageClass(StorageClass storageClass);
619     MemoryLayout memoryLayoutForVariable(const Variable&) const;
620 
621     struct EntrypointAdapter {  // Return type of writeEntrypointAdapter(); owns both objects it holds.
622         std::unique_ptr<FunctionDefinition> entrypointDef;
623         std::unique_ptr<FunctionDeclaration> entrypointDecl;
624     };
625 
626     EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
627 
628     struct UniformBuffer {  // Bundle of three owned objects. NOTE(review): presumably populated by writeUniformBuffer() -- confirm.
629         std::unique_ptr<InterfaceBlock> fInterfaceBlock;
630         std::unique_ptr<Variable> fInnerVariable;
631         std::unique_ptr<Type> fStruct;
632     };
633 
634     void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
635 
636     void addRTFlipUniform(Position pos);
637 
638     std::unique_ptr<Expression> identifier(std::string_view name);
639 
640     std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
641             const Variable& combinedSampler);
642 
643     const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
644 
645     uint64_t fCapabilities = 0;
646     SpvId fIdCount = 1;
647     SpvId fGLSLExtendedInstructions;
648     struct Intrinsic {  // Return type of getIntrinsic(IntrinsicKind).
649         IntrinsicOpcodeKind opKind;
650         int32_t floatOp;  // NOTE(review): the four *Op fields presumably hold the opcode to use per operand type (float / signed / unsigned / bool) -- confirm.
651         int32_t signedOp;
652         int32_t unsignedOp;
653         int32_t boolOp;
654     };
655     Intrinsic getIntrinsic(IntrinsicKind) const;
656 
657     THashMap<Analysis::SpecializedFunctionKey, SpvId, Analysis::SpecializedFunctionKey::Hash>
658             fFunctionMap;
659 
660     Analysis::SpecializationInfo fSpecializationInfo;
661     Analysis::SpecializationIndex fActiveSpecializationIndex = Analysis::kUnspecialized;
662     const Analysis::SpecializedParameters* fActiveSpecialization = nullptr;
663 
664     THashMap<const Variable*, SpvId> fVariableMap;
665     THashMap<const Type*, SpvId> fStructMap;
666     StringStream fGlobalInitializersBuffer;
667     StringStream fConstantBuffer;
668     StringStream fVariableBuffer;
669     StringStream fNameBuffer;
670     StringStream fDecorationBuffer;
671 
672     // Mapping from combined sampler declarations to synthesized texture/sampler variables.
673     // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
674     bool fUseTextureSamplerPairs = false;
675     struct SynthesizedTextureSamplerPair {  // Mapped type of fSynthesizedSamplerMap, which holds each pair through a unique_ptr.
676         // The names of the synthesized variables. The Variable objects themselves store string
677         // views referencing these strings. It is important for the std::string instances to have a
678         // fixed memory location after the string views get created, which is why
679         // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
680         std::string fTextureName;
681         std::string fSamplerName;
682         std::unique_ptr<Variable> fTexture;  // Owned synthesized Variable.
683         std::unique_ptr<Variable> fSampler;  // Owned synthesized Variable.
684     };
685     THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
686             fSynthesizedSamplerMap;
687 
688     // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
689     // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
690     // and simplify code we've already emitted  (by taking a SpvId from an Instruction and following
691     // it back to its source).
692 
693     // A map of instruction -> SpvId:
694     THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
695     // A map of SpvId -> instruction:
696     THashMap<SpvId, Instruction> fSpvIdCache;
697     // A map of SpvId -> value SpvId:
698     THashMap<SpvId, SpvId> fStoreCache;
699 
700     // "Reachable" ops are instructions which can safely be accessed from the current block.
701     // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
702     // reuse that computation on following lines. However, if that Add operation occurred inside an
703     // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
704     // depending on the if condition, we may or may not have actually done that computation). The
705     // same logic applies to other control-flow blocks as well. Once an instruction becomes
706     // unreachable, we remove it from both op-caches.
707     TArray<SpvId> fReachableOps;
708 
709     // The "store-ops" list contains a running list of all the pointers in the store cache. If a
710     // store occurs inside of a conditional block, once that block exits, we no longer know what is
711     // stored in that particular SpvId. At that point, we must remove any associated entry from the
712     // store cache.
713     TArray<SpvId> fStoreOps;
714 
715     // label of the current block, or 0 if we are not in a block
716     SpvId fCurrentBlock = 0;
717     TArray<SpvId> fBreakTarget;
718     TArray<SpvId> fContinueTarget;
719     bool fWroteRTFlip = false;
720     // holds variables synthesized during output, for lifetime purposes
721     SymbolTable fSynthetics{/*builtin=*/true};
722     // Holds a list of uniforms that were declared as globals at the top-level instead of in an
723     // interface block.
724     UniformBuffer fUniformBuffer;
725     std::vector<const VarDeclaration*> fTopLevelUniforms;
726     THashMap<const Variable*, int> fTopLevelUniformMap;  // <var, UniformBuffer field index>
727     SpvId fUniformBufferId = NA;
728 
729     friend class PointerLValue;
730     friend class SwizzleLValue;
731 };
732 
733 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const734 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
735     return fOp         == that.fOp &&
736            fResultKind == that.fResultKind &&
737            fWords      == that.fWords;
738 }
739 
740 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash741     uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
742         uint32_t hash = key.fResultKind;
743         hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
744         hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
745         return hash;
746     }
747 };
748 
749 // This class is used to pass values and result placeholder slots to writeInstruction.
750 struct SPIRVCodeGenerator::Word {
751     enum Kind {
752         kNone,  // intended for use as a sentinel, not part of any Instruction
753         kSpvId,
754         kNumber,
755         kDefaultPrecisionResult,
756         kRelaxedPrecisionResult,
757         kUniqueResult,
758         kKeyedResult,
759     };
760 
WordSkSL::SPIRVCodeGenerator::Word761     Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word762     Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
763 
NumberSkSL::SPIRVCodeGenerator::Word764     static Word Number(int32_t val) {
765         return Word{val, Kind::kNumber};
766     }
767 
ResultSkSL::SPIRVCodeGenerator::Word768     static Word Result(const Type& type) {
769         return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
770     }
771 
RelaxedResultSkSL::SPIRVCodeGenerator::Word772     static Word RelaxedResult() {
773         return Word{(int32_t)NA, kRelaxedPrecisionResult};
774     }
775 
UniqueResultSkSL::SPIRVCodeGenerator::Word776     static Word UniqueResult() {
777         return Word{(int32_t)NA, kUniqueResult};
778     }
779 
ResultSkSL::SPIRVCodeGenerator::Word780     static Word Result() {
781         return Word{(int32_t)NA, kDefaultPrecisionResult};
782     }
783 
784     // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
785     // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
786     // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word787     static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
788 
isResultSkSL::SPIRVCodeGenerator::Word789     bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
790 
791     int32_t fValue;
792     Kind fKind;
793 };
794 
// The generator's magic number. Skia's registered tool ID (31) occupies the upper 16 bits; the
// lower 16 bits are free to carry a version of the SkSL generator, should we ever want one.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 0x001F0000;
799 
// Returns the Intrinsic table entry for `ik`: an opcode kind plus one opcode per operand type
// (float, signed, unsigned, bool). Intrinsic kinds which are not listed below yield an entry
// whose kind is kInvalid_IntrinsicOpcodeKind and whose opcodes are all zero.
SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {

// Helper macros for building Intrinsic entries. SpvOpUndef marks an operand type for which the
// entry supplies no opcode.
// NOTE(review): apart from PACK, these macros are not #undef'd at the end of the function.
//
// ALL_GLSL:     the same GLSL.std.450 extended instruction for every operand type.
// BY_TYPE_GLSL: distinct GLSL.std.450 instructions for float/signed/unsigned; none for bool.
// ALL_SPIRV:    the same core SPIR-V opcode for every operand type.
// BOOL_SPIRV:   a core SPIR-V opcode for bool operands only.
// FLOAT_SPIRV:  a core SPIR-V opcode for float operands only.
// SPECIAL:      a k*_SpecialIntrinsic ID, stored in every slot, with the kSpecial opcode kind.
#define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                              GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                       GLSLstd450 ## ifFloat,             \
                                                       GLSLstd450 ## ifInt,               \
                                                       GLSLstd450 ## ifUInt,              \
                                                       SpvOpUndef}
#define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                               SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
#define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
#define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
#define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
                             k ## x ## _SpecialIntrinsic}

    switch (ik) {
        case k_round_IntrinsicKind:          return ALL_GLSL(Round);
        case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
        case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
        // abs and sign reuse the signed instruction in the unsigned slot.
        case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
        case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
        case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
        case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
        case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
        case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
        case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
        case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
        case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
        case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
        case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
        case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
        case k_atan_IntrinsicKind:           return SPECIAL(Atan);
        case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
        case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
        case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
        case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
        case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
        case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
        case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
        case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
        case k_log_IntrinsicKind:            return ALL_GLSL(Log);
        case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
        case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
        case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
        // `inverse` and `matrixInverse` (below) both map onto MatrixInverse.
        case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
        case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
        case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
        case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
        case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
        case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
        case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
        case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
        case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
        case k_mod_IntrinsicKind:            return SPECIAL(Mod);
        case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
        case k_min_IntrinsicKind:            return SPECIAL(Min);
        case k_max_IntrinsicKind:            return SPECIAL(Max);
        case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
        case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
        case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
        case k_mix_IntrinsicKind:            return SPECIAL(Mix);
        case k_step_IntrinsicKind:           return SPECIAL(Step);
        case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
        case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
        case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
        case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);

// PACK(type) expands to two cases: one for pack<type> and one for unpack<type>.
#define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
                   case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
        PACK(Snorm4x8);
        PACK(Unorm4x8);
        PACK(Snorm2x16);
        PACK(Unorm2x16);
        PACK(Half2x16);
#undef PACK

        case k_length_IntrinsicKind:        return ALL_GLSL(Length);
        case k_distance_IntrinsicKind:      return ALL_GLSL(Distance);
        case k_cross_IntrinsicKind:         return ALL_GLSL(Cross);
        case k_normalize_IntrinsicKind:     return ALL_GLSL(Normalize);
        case k_faceforward_IntrinsicKind:   return ALL_GLSL(FaceForward);
        case k_reflect_IntrinsicKind:       return ALL_GLSL(Reflect);
        case k_refract_IntrinsicKind:       return ALL_GLSL(Refract);
        case k_bitCount_IntrinsicKind:      return ALL_SPIRV(BitCount);
        case k_findLSB_IntrinsicKind:       return ALL_GLSL(FindILsb);
        case k_findMSB_IntrinsicKind:       return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
        case k_dFdx_IntrinsicKind:          return FLOAT_SPIRV(DPdx);
        // Unlike dFdx, dFdy goes through the kSpecial opcode kind.
        case k_dFdy_IntrinsicKind:          return SPECIAL(DFdy);
        case k_fwidth_IntrinsicKind:        return FLOAT_SPIRV(Fwidth);

        // Sampling and subpass-input intrinsics.
        case k_sample_IntrinsicKind:      return SPECIAL(Texture);
        case k_sampleGrad_IntrinsicKind:  return SPECIAL(TextureGrad);
        case k_sampleLod_IntrinsicKind:   return SPECIAL(TextureLod);
        case k_subpassLoad_IntrinsicKind: return SPECIAL(SubpassLoad);

        // Texture read/write and size-query intrinsics.
        case k_textureRead_IntrinsicKind:  return SPECIAL(TextureRead);
        case k_textureWrite_IntrinsicKind:  return SPECIAL(TextureWrite);
        case k_textureWidth_IntrinsicKind:  return SPECIAL(TextureWidth);
        case k_textureHeight_IntrinsicKind:  return SPECIAL(TextureHeight);

        // Every bit-reinterpretation intrinsic maps onto OpBitcast.
        case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
        case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);

        case k_any_IntrinsicKind:   return BOOL_SPIRV(Any);
        case k_all_IntrinsicKind:   return BOOL_SPIRV(All);
        case k_not_IntrinsicKind:   return BOOL_SPIRV(LogicalNot);

        // Comparison intrinsics spell out an opcode per operand type. Equality and inequality
        // share one integer opcode for the signed and unsigned slots and also supply a bool
        // opcode; the ordering comparisons have distinct signed/unsigned opcodes and no bool
        // opcode.
        case k_equal_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdEqual,
                             SpvOpIEqual,
                             SpvOpIEqual,
                             SpvOpLogicalEqual};
        case k_notEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFUnordNotEqual,
                             SpvOpINotEqual,
                             SpvOpINotEqual,
                             SpvOpLogicalNotEqual};
        case k_lessThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThan,
                             SpvOpSLessThan,
                             SpvOpULessThan,
                             SpvOpUndef};
        case k_lessThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThanEqual,
                             SpvOpSLessThanEqual,
                             SpvOpULessThanEqual,
                             SpvOpUndef};
        case k_greaterThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThan,
                             SpvOpSGreaterThan,
                             SpvOpUGreaterThan,
                             SpvOpUndef};
        case k_greaterThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThanEqual,
                             SpvOpSGreaterThanEqual,
                             SpvOpUGreaterThanEqual,
                             SpvOpUndef};

        case k_atomicAdd_IntrinsicKind:   return SPECIAL(AtomicAdd);
        case k_atomicLoad_IntrinsicKind:  return SPECIAL(AtomicLoad);
        case k_atomicStore_IntrinsicKind: return SPECIAL(AtomicStore);

        case k_storageBarrier_IntrinsicKind:   return SPECIAL(StorageBarrier);
        case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);

        default:
            // No table entry exists for this intrinsic kind.
            return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
    }
}
961 
writeWord(int32_t word,OutputStream & out)962 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
963     out.write((const char*) &word, sizeof(word));
964 }
965 
is_float(const Type & type)966 static bool is_float(const Type& type) {
967     return (type.isScalar() || type.isVector() || type.isMatrix()) &&
968            type.componentType().isFloat();
969 }
970 
is_signed(const Type & type)971 static bool is_signed(const Type& type) {
972     return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
973 }
974 
is_unsigned(const Type & type)975 static bool is_unsigned(const Type& type) {
976     return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
977 }
978 
is_bool(const Type & type)979 static bool is_bool(const Type& type) {
980     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
981 }
982 
983 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)984 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
985     if (is_float(type)) {
986         return ifFloat;
987     }
988     if (is_signed(type)) {
989         return ifInt;
990     }
991     if (is_unsigned(type)) {
992         return ifUInt;
993     }
994     if (is_bool(type)) {
995         return ifBool;
996     }
997     SkDEBUGFAIL("unrecognized type");
998     return ifFloat;
999 }
1000 
is_out(ModifierFlags f)1001 static bool is_out(ModifierFlags f) {
1002     return SkToBool(f & ModifierFlag::kOut);
1003 }
1004 
is_in(ModifierFlags f)1005 static bool is_in(ModifierFlags f) {
1006     if (f & ModifierFlag::kIn) {
1007         return true;  // `in` and `inout` both count
1008     }
1009     // If neither in/out flag is set, the type is implicitly `in`.
1010     return !SkToBool(f & ModifierFlag::kOut);
1011 }
1012 
is_control_flow_op(SpvOp_ op)1013 static bool is_control_flow_op(SpvOp_ op) {
1014     switch (op) {
1015         case SpvOpReturn:
1016         case SpvOpReturnValue:
1017         case SpvOpKill:
1018         case SpvOpSwitch:
1019         case SpvOpBranch:
1020         case SpvOpBranchConditional:
1021             return true;
1022         default:
1023             return false;
1024     }
1025 }
1026 
is_globally_reachable_op(SpvOp_ op)1027 static bool is_globally_reachable_op(SpvOp_ op) {
1028     switch (op) {
1029         case SpvOpConstant:
1030         case SpvOpConstantTrue:
1031         case SpvOpConstantFalse:
1032         case SpvOpConstantComposite:
1033         case SpvOpTypeVoid:
1034         case SpvOpTypeInt:
1035         case SpvOpTypeFloat:
1036         case SpvOpTypeBool:
1037         case SpvOpTypeVector:
1038         case SpvOpTypeMatrix:
1039         case SpvOpTypeArray:
1040         case SpvOpTypePointer:
1041         case SpvOpTypeFunction:
1042         case SpvOpTypeRuntimeArray:
1043         case SpvOpTypeStruct:
1044         case SpvOpTypeImage:
1045         case SpvOpTypeSampledImage:
1046         case SpvOpTypeSampler:
1047         case SpvOpVariable:
1048         case SpvOpFunction:
1049         case SpvOpFunctionParameter:
1050         case SpvOpFunctionEnd:
1051         case SpvOpExecutionMode:
1052         case SpvOpMemoryModel:
1053         case SpvOpCapability:
1054         case SpvOpExtInstImport:
1055         case SpvOpEntryPoint:
1056         case SpvOpSource:
1057         case SpvOpSourceExtension:
1058         case SpvOpName:
1059         case SpvOpMemberName:
1060         case SpvOpDecorate:
1061         case SpvOpMemberDecorate:
1062             return true;
1063         default:
1064             return false;
1065     }
1066 }
1067 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1068 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1069     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1070     SkASSERT(opCode != SpvOpUndef);
1071     bool foundDeadCode = false;
1072     if (is_control_flow_op(opCode)) {
1073         // This instruction causes us to leave the current block.
1074         foundDeadCode = (fCurrentBlock == 0);
1075         fCurrentBlock = 0;
1076     } else if (!is_globally_reachable_op(opCode)) {
1077         foundDeadCode = (fCurrentBlock == 0);
1078     }
1079 
1080     if (foundDeadCode) {
1081         // We just encountered dead code--an instruction that don't have an associated block.
1082         // Synthesize a label if this happens; this is necessary to satisfy the validator.
1083         this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1084     }
1085 
1086     this->writeWord((length << 16) | opCode, out);
1087 }
1088 
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1089 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1090     // The straight-line label type is not important; in any case, no caches are invalidated.
1091     SkASSERT(!fCurrentBlock);
1092     fCurrentBlock = label;
1093     this->writeInstruction(SpvOpLabel, label, out);
1094 }
1095 
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1096 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1097                                     ConditionalOpCounts ops, OutputStream& out) {
1098     switch (type) {
1099         case kBranchIsBelow:
1100         case kBranchesOnBothSides:
1101             // With a backward or bidirectional branch, we haven't seen the code between the label
1102             // and the branch yet, so any stored value is potentially suspect. Without scanning
1103             // ahead to check, the only safe option is to ditch the store cache entirely.
1104             fStoreCache.reset();
1105             [[fallthrough]];
1106 
1107         case kBranchIsAbove:
1108             // With a forward branch, we can rely on stores that we had cached at the start of the
1109             // statement/expression, if they haven't been touched yet. Anything newer than that is
1110             // pruned.
1111             this->pruneConditionalOps(ops);
1112             break;
1113     }
1114 
1115     // Emit the label.
1116     this->writeLabel(label, kBranchlessBlock, out);
1117 }
1118 
writeInstruction(SpvOp_ opCode,OutputStream & out)1119 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1120     this->writeOpCode(opCode, 1, out);
1121 }
1122 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1123 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1124     this->writeOpCode(opCode, 2, out);
1125     this->writeWord(word1, out);
1126 }
1127 
writeString(std::string_view s,OutputStream & out)1128 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1129     out.write(s.data(), s.length());
1130     switch (s.length() % 4) {
1131         case 1:
1132             out.write8(0);
1133             [[fallthrough]];
1134         case 2:
1135             out.write8(0);
1136             [[fallthrough]];
1137         case 3:
1138             out.write8(0);
1139             break;
1140         default:
1141             this->writeWord(0, out);
1142             break;
1143     }
1144 }
1145 
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1146 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1147                                           OutputStream& out) {
1148     this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1149     this->writeString(string, out);
1150 }
1151 
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1152 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1153                                           OutputStream& out) {
1154     this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1155     this->writeWord(word1, out);
1156     this->writeString(string, out);
1157 }
1158 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1159 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1160                                           std::string_view string, OutputStream& out) {
1161     this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1162     this->writeWord(word1, out);
1163     this->writeWord(word2, out);
1164     this->writeString(string, out);
1165 }
1166 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1167 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1168                                           OutputStream& out) {
1169     this->writeOpCode(opCode, 3, out);
1170     this->writeWord(word1, out);
1171     this->writeWord(word2, out);
1172 }
1173 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1174 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1175                                           int32_t word3, OutputStream& out) {
1176     this->writeOpCode(opCode, 4, out);
1177     this->writeWord(word1, out);
1178     this->writeWord(word2, out);
1179     this->writeWord(word3, out);
1180 }
1181 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1182 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1183                                           int32_t word3, int32_t word4, OutputStream& out) {
1184     this->writeOpCode(opCode, 5, out);
1185     this->writeWord(word1, out);
1186     this->writeWord(word2, out);
1187     this->writeWord(word3, out);
1188     this->writeWord(word4, out);
1189 }
1190 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1191 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1192                                           int32_t word3, int32_t word4, int32_t word5,
1193                                           OutputStream& out) {
1194     this->writeOpCode(opCode, 6, out);
1195     this->writeWord(word1, out);
1196     this->writeWord(word2, out);
1197     this->writeWord(word3, out);
1198     this->writeWord(word4, out);
1199     this->writeWord(word5, out);
1200 }
1201 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1202 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1203                                           int32_t word3, int32_t word4, int32_t word5,
1204                                           int32_t word6, OutputStream& out) {
1205     this->writeOpCode(opCode, 7, out);
1206     this->writeWord(word1, out);
1207     this->writeWord(word2, out);
1208     this->writeWord(word3, out);
1209     this->writeWord(word4, out);
1210     this->writeWord(word5, out);
1211     this->writeWord(word6, out);
1212 }
1213 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1214 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1215                                           int32_t word3, int32_t word4, int32_t word5,
1216                                           int32_t word6, int32_t word7, OutputStream& out) {
1217     this->writeOpCode(opCode, 8, out);
1218     this->writeWord(word1, out);
1219     this->writeWord(word2, out);
1220     this->writeWord(word3, out);
1221     this->writeWord(word4, out);
1222     this->writeWord(word5, out);
1223     this->writeWord(word6, out);
1224     this->writeWord(word7, out);
1225 }
1226 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1227 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1228                                           int32_t word3, int32_t word4, int32_t word5,
1229                                           int32_t word6, int32_t word7, int32_t word8,
1230                                           OutputStream& out) {
1231     this->writeOpCode(opCode, 9, out);
1232     this->writeWord(word1, out);
1233     this->writeWord(word2, out);
1234     this->writeWord(word3, out);
1235     this->writeWord(word4, out);
1236     this->writeWord(word5, out);
1237     this->writeWord(word6, out);
1238     this->writeWord(word7, out);
1239     this->writeWord(word8, out);
1240 }
1241 
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1242 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1243                                                                         const TArray<Word>& words) {
1244     // Assemble a cache key for this instruction.
1245     Instruction key;
1246     key.fOp = opCode;
1247     key.fWords.resize(words.size());
1248     key.fResultKind = Word::Kind::kNone;
1249 
1250     for (int index = 0; index < words.size(); ++index) {
1251         const Word& word = words[index];
1252         key.fWords[index] = word.fValue;
1253         if (word.isResult()) {
1254             SkASSERT(key.fResultKind == Word::Kind::kNone);
1255             key.fResultKind = word.fKind;
1256         }
1257     }
1258 
1259     return key;
1260 }
1261 
writeInstruction(SpvOp_ opCode,const TArray<Word> & words,OutputStream & out)1262 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
1263                                            const TArray<Word>& words,
1264                                            OutputStream& out) {
1265     // writeOpLoad and writeOpStore have dedicated code.
1266     SkASSERT(opCode != SpvOpLoad);
1267     SkASSERT(opCode != SpvOpStore);
1268 
1269     // If this instruction exists in our op cache, return the cached SpvId.
1270     Instruction key = BuildInstructionKey(opCode, words);
1271     if (SpvId* cachedOp = fOpCache.find(key)) {
1272         return *cachedOp;
1273     }
1274 
1275     SpvId result = NA;
1276     Precision precision = Precision::kDefault;
1277 
1278     switch (key.fResultKind) {
1279         case Word::Kind::kUniqueResult:
1280             // The instruction returns a SpvId, but we do not want deduplication.
1281             result = this->nextId(Precision::kDefault);
1282             fSpvIdCache.set(result, key);
1283             break;
1284 
1285         case Word::Kind::kNone:
1286             // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
1287             fOpCache.set(key, result);
1288             break;
1289 
1290         case Word::Kind::kRelaxedPrecisionResult:
1291             precision = Precision::kRelaxed;
1292             [[fallthrough]];
1293 
1294         case Word::Kind::kKeyedResult:
1295             [[fallthrough]];
1296 
1297         case Word::Kind::kDefaultPrecisionResult:
1298             // Consume a new SpvId.
1299             result = this->nextId(precision);
1300             fOpCache.set(key, result);
1301             fSpvIdCache.set(result, key);
1302 
1303             // Globally-reachable ops are not subject to the whims of flow control.
1304             if (!is_globally_reachable_op(opCode)) {
1305                 fReachableOps.push_back(result);
1306             }
1307             break;
1308 
1309         default:
1310             SkDEBUGFAIL("unexpected result kind");
1311             break;
1312     }
1313 
1314     // Write the requested instruction.
1315     this->writeOpCode(opCode, words.size() + 1, out);
1316     for (const Word& word : words) {
1317         if (word.isResult()) {
1318             SkASSERT(result != NA);
1319             this->writeWord(result, out);
1320         } else {
1321             this->writeWord(word.fValue, out);
1322         }
1323     }
1324 
1325     // Return the result.
1326     return result;
1327 }
1328 
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1329 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1330                                       Precision precision,
1331                                       SpvId pointer,
1332                                       OutputStream& out) {
1333     // Look for this pointer in our load-cache.
1334     if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1335         return *cachedOp;
1336     }
1337 
1338     // Write the requested OpLoad instruction.
1339     SpvId result = this->nextId(precision);
1340     this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1341     return result;
1342 }
1343 
writeOpStore(StorageClass storageClass,SpvId pointer,SpvId value,OutputStream & out)1344 void SPIRVCodeGenerator::writeOpStore(StorageClass storageClass,
1345                                       SpvId pointer,
1346                                       SpvId value,
1347                                       OutputStream& out) {
1348     // Write the uncached SpvOpStore directly.
1349     this->writeInstruction(SpvOpStore, pointer, value, out);
1350 
1351     if (storageClass == StorageClass::kFunction) {
1352         // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1353         // return the cached value as-is.
1354         fStoreCache.set(pointer, value);
1355         fStoreOps.push_back(pointer);
1356     }
1357 }
1358 
writeOpConstantTrue(const Type & type)1359 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1360     return this->writeInstruction(SpvOpConstantTrue,
1361                                   Words{this->getType(type), Word::Result()},
1362                                   fConstantBuffer);
1363 }
1364 
writeOpConstantFalse(const Type & type)1365 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1366     return this->writeInstruction(SpvOpConstantFalse,
1367                                   Words{this->getType(type), Word::Result()},
1368                                   fConstantBuffer);
1369 }
1370 
writeOpConstant(const Type & type,int32_t valueBits)1371 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1372     return this->writeInstruction(
1373             SpvOpConstant,
1374             Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1375             fConstantBuffer);
1376 }
1377 
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1378 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1379                                                    const TArray<SpvId>& values) {
1380     SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1381 
1382     Words words;
1383     words.push_back(this->getType(type));
1384     words.push_back(Word::Result());
1385     for (SpvId value : values) {
1386         words.push_back(value);
1387     }
1388     return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1389 }
1390 
toConstants(SpvId value,TArray<SpvId> * constants)1391 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1392     Instruction* instr = fSpvIdCache.find(value);
1393     if (!instr) {
1394         return false;
1395     }
1396     switch (instr->fOp) {
1397         case SpvOpConstant:
1398         case SpvOpConstantTrue:
1399         case SpvOpConstantFalse:
1400             constants->push_back(value);
1401             return true;
1402 
1403         case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1404             // Start at word 2 to skip past ResultType and ResultID.
1405             for (int i = 2; i < instr->fWords.size(); ++i) {
1406                 if (!this->toConstants(instr->fWords[i], constants)) {
1407                     return false;
1408                 }
1409             }
1410             return true;
1411 
1412         default:
1413             return false;
1414     }
1415 }
1416 
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1417 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1418     for (SpvId value : values) {
1419         if (!this->toConstants(value, constants)) {
1420             return false;
1421         }
1422     }
1423     return true;
1424 }
1425 
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1426 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1427                                                     const TArray<SpvId>& values,
1428                                                     OutputStream& out) {
1429     // If this is a vector composed entirely of literals, write a constant-composite instead.
1430     if (type.isVector()) {
1431         STArray<4, SpvId> constants;
1432         if (this->toConstants(SkSpan(values), &constants)) {
1433             // Create a vector from literals.
1434             return this->writeOpConstantComposite(type, constants);
1435         }
1436     }
1437 
1438     // If this is a matrix composed entirely of literals, constant-composite them instead.
1439     if (type.isMatrix()) {
1440         STArray<16, SpvId> constants;
1441         if (this->toConstants(SkSpan(values), &constants)) {
1442             // Create each matrix column.
1443             SkASSERT(type.isMatrix());
1444             const Type& vecType = type.columnType(fContext);
1445             STArray<4, SpvId> columnIDs;
1446             for (int index=0; index < type.columns(); ++index) {
1447                 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1448                                                     type.rows());
1449                 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1450             }
1451             // Compose the matrix from its columns.
1452             return this->writeOpConstantComposite(type, columnIDs);
1453         }
1454     }
1455 
1456     Words words;
1457     words.push_back(this->getType(type));
1458     words.push_back(Word::Result(type));
1459     for (SpvId value : values) {
1460         words.push_back(value);
1461     }
1462 
1463     return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1464 }
1465 
resultTypeForInstruction(const Instruction & instr)1466 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1467         const Instruction& instr) {
1468     // This list should contain every op that we cache that has a result and result-type.
1469     // (If one is missing, we will not find some optimization opportunities.)
1470     // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1471     // universally true, so it's configurable on a per-op basis.
1472     int resultTypeWord;
1473     switch (instr.fOp) {
1474         case SpvOpConstant:
1475         case SpvOpConstantTrue:
1476         case SpvOpConstantFalse:
1477         case SpvOpConstantComposite:
1478         case SpvOpCompositeConstruct:
1479         case SpvOpCompositeExtract:
1480         case SpvOpLoad:
1481             resultTypeWord = 0;
1482             break;
1483 
1484         default:
1485             return nullptr;
1486     }
1487 
1488     Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1489     SkASSERT(typeInstr);
1490     return typeInstr;
1491 }
1492 
numComponentsForVecInstruction(const Instruction & instr)1493 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1494     // If an instruction is in the op cache, its type should be as well.
1495     Instruction* typeInstr = this->resultTypeForInstruction(instr);
1496     SkASSERT(typeInstr);
1497     SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1498              typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1499 
1500     // For vectors, extract their column count. Scalars have one component by definition.
1501     //   SpvOpTypeVector ResultID ComponentType NumComponents
1502     return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1503                                                : 1;
1504 }
1505 
toComponent(SpvId id,int component)1506 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1507     Instruction* instr = fSpvIdCache.find(id);
1508     if (!instr) {
1509         return NA;
1510     }
1511     if (instr->fOp == SpvOpConstantComposite) {
1512         // SpvOpConstantComposite ResultType ResultID [components...]
1513         // Add 2 to the component index to skip past ResultType and ResultID.
1514         return instr->fWords[2 + component];
1515     }
1516     if (instr->fOp == SpvOpCompositeConstruct) {
1517         // SpvOpCompositeConstruct ResultType ResultID [components...]
1518         // Vectors have special rules; check to see if we are composing a vector.
1519         Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1520         SkASSERT(composedType);
1521 
1522         // When composing a non-vector, each instruction word maps 1:1 to the component index.
1523         // We can just extract out the associated component directly.
1524         if (composedType->fOp != SpvOpTypeVector) {
1525             return instr->fWords[2 + component];
1526         }
1527 
1528         // When composing a vector, components can be either scalars or vectors.
1529         // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1530         for (int index = 2; index < instr->fWords.size(); ++index) {
1531             int32_t currentWord = instr->fWords[index];
1532 
1533             // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1534             Instruction* subinstr = fSpvIdCache.find(currentWord);
1535             if (!subinstr) {
1536                 return NA;
1537             }
1538             // If this subinstruction contains the component we're looking for...
1539             int numComponents = this->numComponentsForVecInstruction(*subinstr);
1540             if (component < numComponents) {
1541                 if (numComponents == 1) {
1542                     // ... it's a scalar. Return it.
1543                     SkASSERT(component == 0);
1544                     return currentWord;
1545                 } else {
1546                     // ... it's a vector. Recurse into it.
1547                     return this->toComponent(currentWord, component);
1548                 }
1549             }
1550             // This sub-instruction doesn't contain our component. Keep walking forward.
1551             component -= numComponents;
1552         }
1553         SkDEBUGFAIL("component index goes past the end of this composite value");
1554         return NA;
1555     }
1556     return NA;
1557 }
1558 
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1559 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1560                                                   SpvId base,
1561                                                   int component,
1562                                                   OutputStream& out) {
1563     // If the base op is a composite, we can extract from it directly.
1564     SpvId result = this->toComponent(base, component);
1565     if (result != NA) {
1566         return result;
1567     }
1568     return this->writeInstruction(
1569             SpvOpCompositeExtract,
1570             {this->getType(type), Word::Result(type), base, Word::Number(component)},
1571             out);
1572 }
1573 
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1574 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1575                                                   SpvId base,
1576                                                   int componentA,
1577                                                   int componentB,
1578                                                   OutputStream& out) {
1579     // If the base op is a composite, we can extract from it directly.
1580     SpvId result = this->toComponent(base, componentA);
1581     if (result != NA) {
1582         return this->writeOpCompositeExtract(type, result, componentB, out);
1583     }
1584     return this->writeInstruction(SpvOpCompositeExtract,
1585                                   {this->getType(type),
1586                                    Word::Result(type),
1587                                    base,
1588                                    Word::Number(componentA),
1589                                    Word::Number(componentB)},
1590                                   out);
1591 }
1592 
writeCapabilities(OutputStream & out)1593 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1594     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1595         if (fCapabilities & bit) {
1596             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1597         }
1598     }
1599     this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1600 }
1601 
nextId(const Type * type)1602 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1603     return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1604                 ? Precision::kRelaxed
1605                 : Precision::kDefault);
1606 }
1607 
nextId(Precision precision)1608 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1609     if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1610         this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1611                                fDecorationBuffer);
1612     }
1613     return fIdCount++;
1614 }
1615 
writeStruct(const Type & type,const MemoryLayout & memoryLayout)1616 SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
1617     // If we've already written out this struct, return its existing SpvId.
1618     if (SpvId* cachedStructId = fStructMap.find(&type)) {
1619         return *cachedStructId;
1620     }
1621 
1622     // Write all of the field types first, so we don't inadvertently write them while we're in the
1623     // middle of writing the struct instruction.
1624     Words words;
1625     words.push_back(Word::UniqueResult());
1626     for (const auto& f : type.fields()) {
1627         words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
1628     }
1629     SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
1630     this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
1631     fStructMap.set(&type, resultId);
1632 
1633     size_t offset = 0;
1634     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
1635         const Field& field = type.fields()[i];
1636         if (!memoryLayout.isSupported(*field.fType)) {
1637             fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
1638                                                     "' is not permitted here");
1639             return resultId;
1640         }
1641         size_t size = memoryLayout.size(*field.fType);
1642         size_t alignment = memoryLayout.alignment(*field.fType);
1643         const Layout& fieldLayout = field.fLayout;
1644         if (fieldLayout.fOffset >= 0) {
1645             if (fieldLayout.fOffset < (int) offset) {
1646                 fContext.fErrors->error(field.fPosition, "offset of field '" +
1647                         std::string(field.fName) + "' must be at least " + std::to_string(offset));
1648             }
1649             if (fieldLayout.fOffset % alignment) {
1650                 fContext.fErrors->error(field.fPosition,
1651                                         "offset of field '" + std::string(field.fName) +
1652                                         "' must be a multiple of " + std::to_string(alignment));
1653             }
1654             offset = fieldLayout.fOffset;
1655         } else {
1656             size_t mod = offset % alignment;
1657             if (mod) {
1658                 offset += alignment - mod;
1659             }
1660         }
1661         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
1662         this->writeFieldLayout(fieldLayout, resultId, i);
1663         if (field.fLayout.fBuiltin < 0) {
1664             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1665                                    (SpvId) offset, fDecorationBuffer);
1666         }
1667         if (field.fType->isMatrix()) {
1668             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1669                                    fDecorationBuffer);
1670             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1671                                    (SpvId) memoryLayout.stride(*field.fType),
1672                                    fDecorationBuffer);
1673         }
1674         if (!field.fType->highPrecision()) {
1675             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
1676                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
1677         }
1678         offset += size;
1679         if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
1680             offset += alignment - offset % alignment;
1681         }
1682     }
1683 
1684     return resultId;
1685 }
1686 
getType(const Type & type)1687 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1688     return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1689 }
1690 
layout_flags_to_image_format(LayoutFlags flags)1691 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1692     flags &= LayoutFlag::kAllPixelFormats;
1693     switch (flags.value()) {
1694         case (int)LayoutFlag::kRGBA8:
1695             return SpvImageFormatRgba8;
1696 
1697         case (int)LayoutFlag::kRGBA32F:
1698             return SpvImageFormatRgba32f;
1699 
1700         case (int)LayoutFlag::kR32F:
1701             return SpvImageFormatR32f;
1702 
1703         default:
1704             return SpvImageFormatUnknown;
1705     }
1706 
1707     SkUNREACHABLE;
1708 }
1709 
getType(const Type & rawType,const Layout & typeLayout,const MemoryLayout & memoryLayout)1710 SpvId SPIRVCodeGenerator::getType(const Type& rawType,
1711                                   const Layout& typeLayout,
1712                                   const MemoryLayout& memoryLayout) {
1713     const Type* type = &rawType;
1714 
1715     switch (type->typeKind()) {
1716         case Type::TypeKind::kVoid: {
1717             return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
1718         }
1719         case Type::TypeKind::kScalar:
1720         case Type::TypeKind::kLiteral: {
1721             if (type->isBoolean()) {
1722                 return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
1723             }
1724             if (type->isSigned()) {
1725                 return this->writeInstruction(
1726                         SpvOpTypeInt,
1727                         Words{Word::Result(), Word::Number(32), Word::Number(1)},
1728                         fConstantBuffer);
1729             }
1730             if (type->isUnsigned()) {
1731                 return this->writeInstruction(
1732                         SpvOpTypeInt,
1733                         Words{Word::Result(), Word::Number(32), Word::Number(0)},
1734                         fConstantBuffer);
1735             }
1736             if (type->isFloat()) {
1737                 return this->writeInstruction(
1738                         SpvOpTypeFloat,
1739                         Words{Word::Result(), Word::Number(32)},
1740                         fConstantBuffer);
1741             }
1742             SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
1743             return NA;
1744         }
1745         case Type::TypeKind::kVector: {
1746             SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
1747             return this->writeInstruction(
1748                     SpvOpTypeVector,
1749                     Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
1750                     fConstantBuffer);
1751         }
1752         case Type::TypeKind::kMatrix: {
1753             SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
1754                                                typeLayout,
1755                                                memoryLayout);
1756             return this->writeInstruction(
1757                     SpvOpTypeMatrix,
1758                     Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
1759                     fConstantBuffer);
1760         }
1761         case Type::TypeKind::kArray: {
1762             const MemoryLayout arrayMemoryLayout =
1763                                     fCaps.fForceStd430ArrayLayout
1764                                         ? MemoryLayout(MemoryLayout::Standard::k430)
1765                                         : memoryLayout;
1766 
1767             if (!arrayMemoryLayout.isSupported(*type)) {
1768                 fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
1769                                                          "' is not permitted here");
1770                 return NA;
1771             }
1772             size_t stride = arrayMemoryLayout.stride(*type);
1773             SpvId typeId = this->getType(type->componentType(), typeLayout, arrayMemoryLayout);
1774             SpvId result = NA;
1775             if (type->isUnsizedArray()) {
1776                 result = this->writeInstruction(SpvOpTypeRuntimeArray,
1777                                                 Words{Word::KeyedResult(stride), typeId},
1778                                                 fConstantBuffer);
1779             } else {
1780                 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
1781                 result = this->writeInstruction(SpvOpTypeArray,
1782                                                 Words{Word::KeyedResult(stride), typeId, countId},
1783                                                 fConstantBuffer);
1784             }
1785             this->writeInstruction(SpvOpDecorate,
1786                                    {result, SpvDecorationArrayStride, Word::Number(stride)},
1787                                    fDecorationBuffer);
1788             return result;
1789         }
1790         case Type::TypeKind::kStruct: {
1791             return this->writeStruct(*type, memoryLayout);
1792         }
1793         case Type::TypeKind::kSeparateSampler: {
1794             return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
1795         }
1796         case Type::TypeKind::kSampler: {
1797             if (SpvDimBuffer == type->dimensions()) {
1798                 fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
1799             }
1800             SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
1801             return this->writeInstruction(SpvOpTypeSampledImage,
1802                                           Words{Word::Result(), imageTypeId},
1803                                           fConstantBuffer);
1804         }
1805         case Type::TypeKind::kTexture: {
1806             SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
1807                                               kDefaultTypeLayout,
1808                                               memoryLayout);
1809 
1810             bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
1811             SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
1812                                             ? layout_flags_to_image_format(typeLayout.fFlags)
1813                                             : SpvImageFormatUnknown;
1814 
1815             return this->writeInstruction(SpvOpTypeImage,
1816                                           Words{Word::Result(),
1817                                                 floatTypeId,
1818                                                 Word::Number(type->dimensions()),
1819                                                 Word::Number(type->isDepth()),
1820                                                 Word::Number(type->isArrayedTexture()),
1821                                                 Word::Number(type->isMultisampled()),
1822                                                 Word::Number(sampled ? 1 : 2),
1823                                                 format},
1824                                           fConstantBuffer);
1825         }
1826         case Type::TypeKind::kAtomic: {
1827             // SkSL currently only supports the atomicUint type.
1828             SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
1829             // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
1830             // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
1831             return this->writeInstruction(SpvOpTypeInt,
1832                                           Words{Word::Result(), Word::Number(32), Word::Number(0)},
1833                                           fConstantBuffer);
1834         }
1835         default: {
1836             SkDEBUGFAILF("invalid type: %s", type->description().c_str());
1837             return NA;
1838         }
1839     }
1840 }
1841 
getFunctionType(const FunctionDeclaration & function)1842 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1843     Words words;
1844     words.push_back(Word::Result());
1845     words.push_back(this->getType(function.returnType()));
1846     for (const Variable* parameter : function.parameters()) {
1847         bool paramIsSpecialized = fActiveSpecialization && fActiveSpecialization->find(parameter);
1848         if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1849             words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1850                                                            parameter->layout()));
1851             if (!paramIsSpecialized) {
1852                 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1853                                                                kDefaultTypeLayout));
1854             }
1855         } else if (!paramIsSpecialized) {
1856             words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1857         }
1858     }
1859     return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1860 }
1861 
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1862 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1863                                                    const Layout& parameterLayout) {
1864     // glslang treats all function arguments as pointers whether they need to be or
1865     // not. I was initially puzzled by this until I ran bizarre failures with certain
1866     // patterns of function calls and control constructs, as exemplified by this minimal
1867     // failure case:
1868     //
1869     // void sphere(float x) {
1870     // }
1871     //
1872     // void map() {
1873     //     sphere(1.0);
1874     // }
1875     //
1876     // void main() {
1877     //     for (int i = 0; i < 1; i++) {
1878     //         map();
1879     //     }
1880     // }
1881     //
1882     // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1883     // crashes. Making it take a float* and storing the argument in a temporary variable,
1884     // as glslang does, fixes it.
1885     //
1886     // The consensus among shader compiler authors seems to be that GPU driver generally don't
1887     // handle value-based parameters consistently. It is highly likely that they fit their
1888     // implementations to conform to glslang. We take care to do so ourselves.
1889     //
1890     // Our implementation first stores every parameter value into a function storage-class pointer
1891     // before calling a function. The exception is for opaque handle types (samplers and textures)
1892     // which must be stored in a pointer with UniformConstant storage-class. This prevents
1893     // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1894     // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1895     StorageClass storageClass;
1896     if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1897         parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1898         parameterType.typeKind() == Type::TypeKind::kTexture) {
1899         storageClass = StorageClass::kUniformConstant;
1900     } else {
1901         storageClass = StorageClass::kFunction;
1902     }
1903     return this->getPointerType(parameterType,
1904                                 parameterLayout,
1905                                 this->memoryLayoutForStorageClass(storageClass),
1906                                 storageClass);
1907 }
1908 
getPointerType(const Type & type,StorageClass storageClass)1909 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, StorageClass storageClass) {
1910     return this->getPointerType(type,
1911                                 kDefaultTypeLayout,
1912                                 this->memoryLayoutForStorageClass(storageClass),
1913                                 storageClass);
1914 }
1915 
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,StorageClass storageClass)1916 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1917                                          const Layout& typeLayout,
1918                                          const MemoryLayout& memoryLayout,
1919                                          StorageClass storageClass) {
1920     return this->writeInstruction(SpvOpTypePointer,
1921                                   Words{Word::Result(),
1922                                         Word::Number(get_storage_class_spv_id(storageClass)),
1923                                         this->getType(type, typeLayout, memoryLayout)},
1924                                   fConstantBuffer);
1925 }
1926 
writeExpression(const Expression & expr,OutputStream & out)1927 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1928     switch (expr.kind()) {
1929         case Expression::Kind::kBinary:
1930             return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1931         case Expression::Kind::kConstructorArrayCast:
1932             return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1933         case Expression::Kind::kConstructorArray:
1934         case Expression::Kind::kConstructorStruct:
1935             return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1936         case Expression::Kind::kConstructorDiagonalMatrix:
1937             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1938         case Expression::Kind::kConstructorMatrixResize:
1939             return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1940         case Expression::Kind::kConstructorScalarCast:
1941             return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1942         case Expression::Kind::kConstructorSplat:
1943             return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1944         case Expression::Kind::kConstructorCompound:
1945             return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1946         case Expression::Kind::kConstructorCompoundCast:
1947             return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1948         case Expression::Kind::kEmpty:
1949             return NA;
1950         case Expression::Kind::kFieldAccess:
1951             return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1952         case Expression::Kind::kFunctionCall:
1953             return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1954         case Expression::Kind::kLiteral:
1955             return this->writeLiteral(expr.as<Literal>());
1956         case Expression::Kind::kPrefix:
1957             return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1958         case Expression::Kind::kPostfix:
1959             return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1960         case Expression::Kind::kSwizzle:
1961             return this->writeSwizzle(expr.as<Swizzle>(), out);
1962         case Expression::Kind::kVariableReference:
1963             return this->writeVariableReference(expr.as<VariableReference>(), out);
1964         case Expression::Kind::kTernary:
1965             return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1966         case Expression::Kind::kIndex:
1967             return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1968         case Expression::Kind::kSetting:
1969             return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
1970         default:
1971             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1972             break;
1973     }
1974     return NA;
1975 }
1976 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1977 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1978     const FunctionDeclaration& function = c.function();
1979     Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1980     if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1981         fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1982                 "'");
1983         return NA;
1984     }
1985     const ExpressionArray& arguments = c.arguments();
1986     int32_t intrinsicId = intrinsic.floatOp;
1987     if (!arguments.empty()) {
1988         const Type& type = arguments[0]->type();
1989         if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1990             // Keep the default float op.
1991         } else {
1992             intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1993                                        intrinsic.unsignedOp, intrinsic.boolOp);
1994         }
1995     }
1996     switch (intrinsic.opKind) {
1997         case kGLSL_STD_450_IntrinsicOpcodeKind: {
1998             SpvId result = this->nextId(&c.type());
1999             TArray<SpvId> argumentIds;
2000             argumentIds.reserve_exact(arguments.size());
2001             std::vector<TempVar> tempVars;
2002             for (int i = 0; i < arguments.size(); i++) {
2003                 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2004                                                 /*specializedParams=*/nullptr, out);
2005             }
2006             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2007             this->writeWord(this->getType(c.type()), out);
2008             this->writeWord(result, out);
2009             this->writeWord(fGLSLExtendedInstructions, out);
2010             this->writeWord(intrinsicId, out);
2011             for (SpvId id : argumentIds) {
2012                 this->writeWord(id, out);
2013             }
2014             this->copyBackTempVars(tempVars, out);
2015             return result;
2016         }
2017         case kSPIRV_IntrinsicOpcodeKind: {
2018             // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
2019             if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
2020                 intrinsicId = SpvOpFMul;
2021             }
2022             SpvId result = this->nextId(&c.type());
2023             TArray<SpvId> argumentIds;
2024             argumentIds.reserve_exact(arguments.size());
2025             std::vector<TempVar> tempVars;
2026             for (int i = 0; i < arguments.size(); i++) {
2027                 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2028                                                 /*specializedParams=*/nullptr, out);
2029             }
2030             if (!c.type().isVoid()) {
2031                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
2032                 this->writeWord(this->getType(c.type()), out);
2033                 this->writeWord(result, out);
2034             } else {
2035                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
2036             }
2037             for (SpvId id : argumentIds) {
2038                 this->writeWord(id, out);
2039             }
2040             this->copyBackTempVars(tempVars, out);
2041             return result;
2042         }
2043         case kSpecial_IntrinsicOpcodeKind:
2044             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
2045         default:
2046             fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
2047                     function.description() + "'");
2048             return NA;
2049     }
2050 }
2051 
vectorize(const Expression & arg,int vectorSize,OutputStream & out)2052 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
2053     SkASSERT(vectorSize >= 1 && vectorSize <= 4);
2054     const Type& argType = arg.type();
2055     if (argType.isScalar() && vectorSize > 1) {
2056         SpvId argID = this->writeExpression(arg, out);
2057         return this->splat(argType.toCompound(fContext, vectorSize, /*rows=*/1), argID, out);
2058     }
2059 
2060     SkASSERT(vectorSize == argType.columns());
2061     return this->writeExpression(arg, out);
2062 }
2063 
vectorize(const ExpressionArray & args,OutputStream & out)2064 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2065     int vectorSize = 1;
2066     for (const auto& a : args) {
2067         if (a->type().isVector()) {
2068             if (vectorSize > 1) {
2069                 SkASSERT(a->type().columns() == vectorSize);
2070             } else {
2071                 vectorSize = a->type().columns();
2072             }
2073         }
2074     }
2075     TArray<SpvId> result;
2076     result.reserve_exact(args.size());
2077     for (const auto& arg : args) {
2078         result.push_back(this->vectorize(*arg, vectorSize, out));
2079     }
2080     return result;
2081 }
2082 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2083 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2084                                                       SpvId signedInst, SpvId unsignedInst,
2085                                                       const TArray<SpvId>& args,
2086                                                       OutputStream& out) {
2087     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2088     this->writeWord(this->getType(type), out);
2089     this->writeWord(id, out);
2090     this->writeWord(fGLSLExtendedInstructions, out);
2091     this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2092     for (SpvId a : args) {
2093         this->writeWord(a, out);
2094     }
2095 }
2096 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2097 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2098                                                 OutputStream& out) {
2099     const ExpressionArray& arguments = c.arguments();
2100     const Type& callType = c.type();
2101     SpvId result = this->nextId(nullptr);
2102     switch (kind) {
2103         case kAtan_SpecialIntrinsic: {
2104             STArray<2, SpvId> argumentIds;
2105             for (const std::unique_ptr<Expression>& arg : arguments) {
2106                 argumentIds.push_back(this->writeExpression(*arg, out));
2107             }
2108             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2109             this->writeWord(this->getType(callType), out);
2110             this->writeWord(result, out);
2111             this->writeWord(fGLSLExtendedInstructions, out);
2112             this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2113             for (SpvId id : argumentIds) {
2114                 this->writeWord(id, out);
2115             }
2116             break;
2117         }
2118         case kSampledImage_SpecialIntrinsic: {
2119             SkASSERT(arguments.size() == 2);
2120             SpvId img = this->writeExpression(*arguments[0], out);
2121             SpvId sampler = this->writeExpression(*arguments[1], out);
2122             this->writeInstruction(SpvOpSampledImage,
2123                                    this->getType(callType),
2124                                    result,
2125                                    img,
2126                                    sampler,
2127                                    out);
2128             break;
2129         }
2130         case kSubpassLoad_SpecialIntrinsic: {
2131             SpvId img = this->writeExpression(*arguments[0], out);
2132             ExpressionArray args;
2133             args.reserve_exact(2);
2134             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2135             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2136             ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2137             SpvId coords = this->writeExpression(ctor, out);
2138             if (arguments.size() == 1) {
2139                 this->writeInstruction(SpvOpImageRead,
2140                                        this->getType(callType),
2141                                        result,
2142                                        img,
2143                                        coords,
2144                                        out);
2145             } else {
2146                 SkASSERT(arguments.size() == 2);
2147                 SpvId sample = this->writeExpression(*arguments[1], out);
2148                 this->writeInstruction(SpvOpImageRead,
2149                                        this->getType(callType),
2150                                        result,
2151                                        img,
2152                                        coords,
2153                                        SpvImageOperandsSampleMask,
2154                                        sample,
2155                                        out);
2156             }
2157             break;
2158         }
2159         case kTexture_SpecialIntrinsic: {
2160             SpvOp_ op = SpvOpImageSampleImplicitLod;
2161             const Type& arg1Type = arguments[1]->type();
2162             switch (arguments[0]->type().dimensions()) {
2163                 case SpvDim1D:
2164                     if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2165                         op = SpvOpImageSampleProjImplicitLod;
2166                     } else {
2167                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2168                     }
2169                     break;
2170                 case SpvDim2D:
2171                     if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2172                         op = SpvOpImageSampleProjImplicitLod;
2173                     } else {
2174                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2175                     }
2176                     break;
2177                 case SpvDim3D:
2178                     if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2179                         op = SpvOpImageSampleProjImplicitLod;
2180                     } else {
2181                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2182                     }
2183                     break;
2184                 case SpvDimCube:   // fall through
2185                 case SpvDimRect:   // fall through
2186                 case SpvDimBuffer: // fall through
2187                 case SpvDimSubpassData:
2188                     break;
2189             }
2190             SpvId type = this->getType(callType);
2191             SpvId sampler = this->writeExpression(*arguments[0], out);
2192             SpvId uv = this->writeExpression(*arguments[1], out);
2193             if (arguments.size() == 3) {
2194                 this->writeInstruction(op, type, result, sampler, uv,
2195                                        SpvImageOperandsBiasMask,
2196                                        this->writeExpression(*arguments[2], out),
2197                                        out);
2198             } else {
2199                 SkASSERT(arguments.size() == 2);
2200                 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2201                     SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2202                                                        *fContext.fTypes.fFloat);
2203                     this->writeInstruction(op, type, result, sampler, uv,
2204                                            SpvImageOperandsBiasMask, lodBias, out);
2205                 } else {
2206                     this->writeInstruction(op, type, result, sampler, uv,
2207                                            out);
2208                 }
2209             }
2210             break;
2211         }
2212         case kTextureGrad_SpecialIntrinsic: {
2213             SpvOp_ op = SpvOpImageSampleExplicitLod;
2214             SkASSERT(arguments.size() == 4);
2215             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2216             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2217             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2218             SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2219             SpvId type = this->getType(callType);
2220             SpvId sampler = this->writeExpression(*arguments[0], out);
2221             SpvId uv = this->writeExpression(*arguments[1], out);
2222             SpvId dPdx = this->writeExpression(*arguments[2], out);
2223             SpvId dPdy = this->writeExpression(*arguments[3], out);
2224             this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2225                                    dPdx, dPdy, out);
2226             break;
2227         }
2228         case kTextureLod_SpecialIntrinsic: {
2229             SpvOp_ op = SpvOpImageSampleExplicitLod;
2230             SkASSERT(arguments.size() == 3);
2231             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2232             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2233             const Type& arg1Type = arguments[1]->type();
2234             if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2235                 op = SpvOpImageSampleProjExplicitLod;
2236             } else {
2237                 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2238             }
2239             SpvId type = this->getType(callType);
2240             SpvId sampler = this->writeExpression(*arguments[0], out);
2241             SpvId uv = this->writeExpression(*arguments[1], out);
2242             this->writeInstruction(op, type, result, sampler, uv,
2243                                    SpvImageOperandsLodMask,
2244                                    this->writeExpression(*arguments[2], out),
2245                                    out);
2246             break;
2247         }
2248         case kTextureRead_SpecialIntrinsic: {
2249             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2250             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2251 
2252             SpvId type = this->getType(callType);
2253             SpvId image = this->writeExpression(*arguments[0], out);
2254             SpvId coord = this->writeExpression(*arguments[1], out);
2255 
2256             const Type& arg0Type = arguments[0]->type();
2257             SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2258 
2259             switch (arg0Type.textureAccess()) {
2260                 case Type::TextureAccess::kSample:
2261                     this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2262                                            SpvImageOperandsLodMask,
2263                                            this->writeOpConstant(*fContext.fTypes.fInt, 0),
2264                                            out);
2265                     break;
2266                 case Type::TextureAccess::kRead:
2267                 case Type::TextureAccess::kReadWrite:
2268                     this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2269                     break;
2270                 case Type::TextureAccess::kWrite:
2271                 default:
2272                     SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2273                     break;
2274             }
2275 
2276             break;
2277         }
2278         case kTextureWrite_SpecialIntrinsic: {
2279             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2280             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2281             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2282 
2283             SpvId image = this->writeExpression(*arguments[0], out);
2284             SpvId coord = this->writeExpression(*arguments[1], out);
2285             SpvId texel = this->writeExpression(*arguments[2], out);
2286 
2287             this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2288             break;
2289         }
2290         case kTextureWidth_SpecialIntrinsic:
2291         case kTextureHeight_SpecialIntrinsic: {
2292             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2293             fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2294 
2295             SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2296             SpvId dims = this->nextId(nullptr);
2297             SpvId image = this->writeExpression(*arguments[0], out);
2298             this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2299 
2300             SpvId type = this->getType(callType);
2301             int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2302             this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2303             break;
2304         }
2305         case kMod_SpecialIntrinsic: {
2306             TArray<SpvId> args = this->vectorize(arguments, out);
2307             SkASSERT(args.size() == 2);
2308             const Type& operandType = arguments[0]->type();
2309             SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2310             SkASSERT(op != SpvOpUndef);
2311             this->writeOpCode(op, 5, out);
2312             this->writeWord(this->getType(operandType), out);
2313             this->writeWord(result, out);
2314             this->writeWord(args[0], out);
2315             this->writeWord(args[1], out);
2316             break;
2317         }
2318         case kDFdy_SpecialIntrinsic: {
2319             SpvId fn = this->writeExpression(*arguments[0], out);
2320             this->writeOpCode(SpvOpDPdy, 4, out);
2321             this->writeWord(this->getType(callType), out);
2322             this->writeWord(result, out);
2323             this->writeWord(fn, out);
2324             if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2325                 this->addRTFlipUniform(c.fPosition);
2326                 ComponentArray componentArray;
2327                 for (int index = 0; index < callType.columns(); ++index) {
2328                     componentArray.push_back(SwizzleComponent::Y);
2329                 }
2330                 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2331                                                    componentArray, out);
2332                 SpvId flipped = this->nextId(&callType);
2333                 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2334                                        rtFlipY, out);
2335                 result = flipped;
2336             }
2337             break;
2338         }
2339         case kClamp_SpecialIntrinsic: {
2340             TArray<SpvId> args = this->vectorize(arguments, out);
2341             SkASSERT(args.size() == 3);
2342             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2343                                                GLSLstd450UClamp, args, out);
2344             break;
2345         }
2346         case kMax_SpecialIntrinsic: {
2347             TArray<SpvId> args = this->vectorize(arguments, out);
2348             SkASSERT(args.size() == 2);
2349             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2350                                                GLSLstd450UMax, args, out);
2351             break;
2352         }
2353         case kMin_SpecialIntrinsic: {
2354             TArray<SpvId> args = this->vectorize(arguments, out);
2355             SkASSERT(args.size() == 2);
2356             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2357                                                GLSLstd450UMin, args, out);
2358             break;
2359         }
2360         case kMix_SpecialIntrinsic: {
2361             TArray<SpvId> args = this->vectorize(arguments, out);
2362             SkASSERT(args.size() == 3);
2363             if (arguments[2]->type().componentType().isBoolean()) {
2364                 // Use OpSelect to implement Boolean mix().
2365                 SpvId falseId     = this->writeExpression(*arguments[0], out);
2366                 SpvId trueId      = this->writeExpression(*arguments[1], out);
2367                 SpvId conditionId = this->writeExpression(*arguments[2], out);
2368                 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2369                                        conditionId, trueId, falseId, out);
2370             } else {
2371                 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2372                                                    SpvOpUndef, args, out);
2373             }
2374             break;
2375         }
2376         case kSaturate_SpecialIntrinsic: {
2377             SkASSERT(arguments.size() == 1);
2378             int width = arguments[0]->type().columns();
2379             STArray<3, SpvId> spvArgs{
2380                 this->vectorize(*arguments[0], width, out),
2381                 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/0), width, out),
2382                 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/1), width, out),
2383             };
2384             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2385                                                GLSLstd450UClamp, spvArgs, out);
2386             break;
2387         }
2388         case kSmoothStep_SpecialIntrinsic: {
2389             TArray<SpvId> args = this->vectorize(arguments, out);
2390             SkASSERT(args.size() == 3);
2391             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2392                                                SpvOpUndef, args, out);
2393             break;
2394         }
2395         case kStep_SpecialIntrinsic: {
2396             TArray<SpvId> args = this->vectorize(arguments, out);
2397             SkASSERT(args.size() == 2);
2398             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2399                                                SpvOpUndef, args, out);
2400             break;
2401         }
2402         case kMatrixCompMult_SpecialIntrinsic: {
2403             SkASSERT(arguments.size() == 2);
2404             SpvId lhs = this->writeExpression(*arguments[0], out);
2405             SpvId rhs = this->writeExpression(*arguments[1], out);
2406             result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2407             break;
2408         }
2409         case kAtomicAdd_SpecialIntrinsic:
2410         case kAtomicLoad_SpecialIntrinsic:
2411         case kAtomicStore_SpecialIntrinsic:
2412             result = this->writeAtomicIntrinsic(c, kind, result, out);
2413             break;
2414         case kStorageBarrier_SpecialIntrinsic:
2415         case kWorkgroupBarrier_SpecialIntrinsic: {
2416             // Both barrier types operate in the workgroup execution and memory scope and differ
2417             // only in memory semantics. storageBarrier() is not a device-scope barrier.
2418             SpvId scopeId =
2419                     this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2420             int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2421                                          ? SpvMemorySemanticsAcquireReleaseMask |
2422                                                    SpvMemorySemanticsUniformMemoryMask
2423                                          : SpvMemorySemanticsAcquireReleaseMask |
2424                                                    SpvMemorySemanticsWorkgroupMemoryMask;
2425             SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2426             this->writeInstruction(SpvOpControlBarrier,
2427                                    scopeId,  // execution scope
2428                                    scopeId,  // memory scope
2429                                    memorySemanticsId,
2430                                    out);
2431             break;
2432         }
2433     }
2434     return result;
2435 }
2436 
writeAtomicIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SpvId resultId,OutputStream & out)2437 SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
2438                                                SpecialIntrinsic kind,
2439                                                SpvId resultId,
2440                                                OutputStream& out) {
2441     const ExpressionArray& arguments = c.arguments();
2442     SkASSERT(!arguments.empty());
2443 
2444     std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
2445     SpvId atomicPtrId = atomicPtr->getPointer();
2446     if (atomicPtrId == NA) {
2447         SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
2448                      arguments[0]->description().c_str());
2449         return NA;
2450     }
2451 
2452     SpvId memoryScopeId = NA;
2453     {
2454         // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
2455         // member. The two memory scopes that these map to are "workgroup" and "device",
2456         // respectively.
2457         SpvScope memoryScope;
2458         switch (atomicPtr->storageClass()) {
2459             case StorageClass::kUniform:
2460             case StorageClass::kStorageBuffer:
2461                 // We encode storage buffers in the uniform address space (with the BufferBlock
2462                 // decorator).
2463                 memoryScope = SpvScopeDevice;
2464                 break;
2465             case StorageClass::kWorkgroup:
2466                 memoryScope = SpvScopeWorkgroup;
2467                 break;
2468             default:
2469                 SkDEBUGFAILF("atomic argument has invalid storage class: %d",
2470                              get_storage_class_spv_id(atomicPtr->storageClass()));
2471                 return NA;
2472         }
2473         memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
2474     }
2475 
2476     SpvId relaxedMemoryOrderId =
2477             this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);
2478 
2479     switch (kind) {
2480         case kAtomicAdd_SpecialIntrinsic:
2481             SkASSERT(arguments.size() == 2);
2482             this->writeInstruction(SpvOpAtomicIAdd,
2483                                    this->getType(c.type()),
2484                                    resultId,
2485                                    atomicPtrId,
2486                                    memoryScopeId,
2487                                    relaxedMemoryOrderId,
2488                                    this->writeExpression(*arguments[1], out),
2489                                    out);
2490             break;
2491         case kAtomicLoad_SpecialIntrinsic:
2492             SkASSERT(arguments.size() == 1);
2493             this->writeInstruction(SpvOpAtomicLoad,
2494                                    this->getType(c.type()),
2495                                    resultId,
2496                                    atomicPtrId,
2497                                    memoryScopeId,
2498                                    relaxedMemoryOrderId,
2499                                    out);
2500             break;
2501         case kAtomicStore_SpecialIntrinsic:
2502             SkASSERT(arguments.size() == 2);
2503             this->writeInstruction(SpvOpAtomicStore,
2504                                    atomicPtrId,
2505                                    memoryScopeId,
2506                                    relaxedMemoryOrderId,
2507                                    this->writeExpression(*arguments[1], out),
2508                                    out);
2509             break;
2510         default:
2511             SkUNREACHABLE;
2512     }
2513 
2514     return resultId;
2515 }
2516 
writeFunctionCallArgument(TArray<SpvId> & argumentList,const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,const SkBitSet * specializedParams,OutputStream & out)2517 void SPIRVCodeGenerator::writeFunctionCallArgument(TArray<SpvId>& argumentList,
2518                                                    const FunctionCall& call,
2519                                                    int argIndex,
2520                                                    std::vector<TempVar>* tempVars,
2521                                                    const SkBitSet* specializedParams,
2522                                                    OutputStream& out) {
2523     const FunctionDeclaration& funcDecl = call.function();
2524     const Expression& arg = *call.arguments()[argIndex];
2525     const Variable* param = funcDecl.parameters()[argIndex];
2526     bool paramIsSpecialized = specializedParams && specializedParams->test(argIndex);
2527     ModifierFlags paramFlags = param->modifierFlags();
2528 
2529     // Ignore the argument since it is specialized, if fUseTextureSamplerPairs is true and this
2530     // argument is a sampler, handle ignoring the sampler below when generating the texture and
2531     // sampler pair arguments.
2532     if (paramIsSpecialized && !(param->type().isSampler() && fUseTextureSamplerPairs)) {
2533         return;
2534     }
2535 
2536     if (arg.is<VariableReference>() && (arg.type().typeKind() == Type::TypeKind::kSampler ||
2537                                         arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2538                                         arg.type().typeKind() == Type::TypeKind::kTexture)) {
2539         // Opaque handle (sampler/texture) arguments are always declared as pointers but never
2540         // stored in intermediates when calling user-defined functions.
2541         //
2542         // The case for intrinsics (which take opaque arguments by value) is handled above just like
2543         // regular pointers.
2544         //
2545         // See getFunctionParameterType for further explanation.
2546         const Variable* var = arg.as<VariableReference>().variable();
2547 
2548         // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
2549         if (fUseTextureSamplerPairs && var->type().isSampler()) {
2550             if (const auto* p = fSynthesizedSamplerMap.find(var)) {
2551                 SpvId* img = fVariableMap.find((*p)->fTexture.get());
2552                 SkASSERT(img);
2553 
2554                 argumentList.push_back(*img);
2555 
2556                 if (!paramIsSpecialized) {
2557                     SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
2558                     SkASSERT(sampler);
2559                     argumentList.push_back(*sampler);
2560                 }
2561                 return;
2562             }
2563             SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
2564         }
2565 
2566         SpvId* entry = fVariableMap.find(var);
2567         SkASSERTF(entry, "%s", arg.description().c_str());
2568         argumentList.push_back(*entry);
2569         return;
2570     }
2571     SkASSERT(!paramIsSpecialized);
2572 
2573     // ID of temporary variable that we will use to hold this argument, or 0 if it is being
2574     // passed directly
2575     SpvId tmpVar = NA;
2576     // if we need a temporary var to store this argument, this is the value to store in the var
2577     SpvId tmpValueId = NA;
2578 
2579     if (is_out(paramFlags)) {
2580         std::unique_ptr<LValue> lv = this->getLValue(arg, out);
2581         // We handle out params with a temp var that we copy back to the original variable at the
2582         // end of the call. GLSL guarantees that the original variable will be unchanged until the
2583         // end of the call, and also that out params are written back to their original variables in
2584         // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
2585         if (is_in(paramFlags)) {
2586             tmpValueId = lv->load(out);
2587         }
2588         tmpVar = this->nextId(&arg.type());
2589         tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
2590     } else if (funcDecl.isIntrinsic()) {
2591         // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
2592         argumentList.push_back(this->writeExpression(arg, out));
2593         return;
2594     } else {
2595         // We always use pointer parameters when calling user functions.
2596         // See getFunctionParameterType for further explanation.
2597         tmpValueId = this->writeExpression(arg, out);
2598         tmpVar = this->nextId(nullptr);
2599     }
2600     this->writeInstruction(SpvOpVariable,
2601                            this->getPointerType(arg.type(), StorageClass::kFunction),
2602                            tmpVar,
2603                            SpvStorageClassFunction,
2604                            fVariableBuffer);
2605     if (tmpValueId != NA) {
2606         this->writeOpStore(StorageClass::kFunction, tmpVar, tmpValueId, out);
2607     }
2608     argumentList.push_back(tmpVar);
2609 }
2610 
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2611 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2612     for (const TempVar& tempVar : tempVars) {
2613         SpvId load = this->nextId(tempVar.type);
2614         this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2615         tempVar.lvalue->store(load, out);
2616     }
2617 }
2618 
writeFunctionCall(const FunctionCall & c,OutputStream & out)2619 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2620     // Handle intrinsics.
2621     const FunctionDeclaration& function = c.function();
2622     if (function.isIntrinsic() && !function.definition()) {
2623         return this->writeIntrinsicCall(c, out);
2624     }
2625 
2626     // Look up this function (or its specialization, if any) in our map of function SpvIds.
2627     Analysis::SpecializationIndex specializationIndex = Analysis::FindSpecializationIndexForCall(
2628             c, fSpecializationInfo, fActiveSpecializationIndex);
2629     SpvId* entry = fFunctionMap.find({&function, specializationIndex});
2630     if (!entry) {
2631         fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2632                                              "' is not defined");
2633         return NA;
2634     }
2635 
2636     // If we are calling a specialized function, we need to gather the specialized parameters
2637     // so we can remove them from the argument list.
2638     SkBitSet specializedParams =
2639             Analysis::FindSpecializedParametersForFunction(c.function(), fSpecializationInfo);
2640 
2641     // Temp variables are used to write back out-parameters after the function call is complete.
2642     const ExpressionArray& arguments = c.arguments();
2643     std::vector<TempVar> tempVars;
2644     TArray<SpvId> argumentIds;
2645     argumentIds.reserve_exact(arguments.size());
2646     for (int i = 0; i < arguments.size(); i++) {
2647         this->writeFunctionCallArgument(argumentIds, c, i, &tempVars, &specializedParams, out);
2648     }
2649     SpvId result = this->nextId(nullptr);
2650     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2651     this->writeWord(this->getType(c.type()), out);
2652     this->writeWord(result, out);
2653     this->writeWord(*entry, out);
2654     for (SpvId id : argumentIds) {
2655         this->writeWord(id, out);
2656     }
2657     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2658     this->copyBackTempVars(tempVars, out);
2659     return result;
2660 }
2661 
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2662 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2663                                            const Type& inputType,
2664                                            const Type& outputType,
2665                                            OutputStream& out) {
2666     if (outputType.isFloat()) {
2667         return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2668     }
2669     if (outputType.isSigned()) {
2670         return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2671     }
2672     if (outputType.isUnsigned()) {
2673         return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2674     }
2675     if (outputType.isBoolean()) {
2676         return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2677     }
2678 
2679     fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2680             outputType.description());
2681     return inputExprId;
2682 }
2683 
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2684 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2685                                             const Type& outputType, OutputStream& out) {
2686     // Casting a float to float is a no-op.
2687     if (inputType.isFloat()) {
2688         return inputId;
2689     }
2690 
2691     // Given the input type, generate the appropriate instruction to cast to float.
2692     SpvId result = this->nextId(&outputType);
2693     if (inputType.isBoolean()) {
2694         // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2695         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2696         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2697         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2698                                inputId, oneID, zeroID, out);
2699     } else if (inputType.isSigned()) {
2700         this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2701     } else if (inputType.isUnsigned()) {
2702         this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2703     } else {
2704         SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2705         return NA;
2706     }
2707     return result;
2708 }
2709 
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2710 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2711                                                 const Type& outputType, OutputStream& out) {
2712     // Casting a signed int to signed int is a no-op.
2713     if (inputType.isSigned()) {
2714         return inputId;
2715     }
2716 
2717     // Given the input type, generate the appropriate instruction to cast to signed int.
2718     SpvId result = this->nextId(&outputType);
2719     if (inputType.isBoolean()) {
2720         // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2721         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2722         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2723         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2724                                inputId, oneID, zeroID, out);
2725     } else if (inputType.isFloat()) {
2726         this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2727     } else if (inputType.isUnsigned()) {
2728         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2729     } else {
2730         SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2731                      inputType.description().c_str());
2732         return NA;
2733     }
2734     return result;
2735 }
2736 
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2737 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2738                                                   const Type& outputType, OutputStream& out) {
2739     // Casting an unsigned int to unsigned int is a no-op.
2740     if (inputType.isUnsigned()) {
2741         return inputId;
2742     }
2743 
2744     // Given the input type, generate the appropriate instruction to cast to unsigned int.
2745     SpvId result = this->nextId(&outputType);
2746     if (inputType.isBoolean()) {
2747         // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2748         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2749         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2750         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2751                                inputId, oneID, zeroID, out);
2752     } else if (inputType.isFloat()) {
2753         this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2754     } else if (inputType.isSigned()) {
2755         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2756     } else {
2757         SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2758                      inputType.description().c_str());
2759         return NA;
2760     }
2761     return result;
2762 }
2763 
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2764 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2765                                               const Type& outputType, OutputStream& out) {
2766     // Casting a bool to bool is a no-op.
2767     if (inputType.isBoolean()) {
2768         return inputId;
2769     }
2770 
2771     // Given the input type, generate the appropriate instruction to cast to bool.
2772     SpvId result = this->nextId(nullptr);
2773     if (inputType.isSigned()) {
2774         // Synthesize a boolean result by comparing the input against a signed zero literal.
2775         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2776         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2777                                inputId, zeroID, out);
2778     } else if (inputType.isUnsigned()) {
2779         // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2780         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2781         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2782                                inputId, zeroID, out);
2783     } else if (inputType.isFloat()) {
2784         // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2785         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2786         this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2787                                inputId, zeroID, out);
2788     } else {
2789         SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2790         return NA;
2791     }
2792     return result;
2793 }
2794 
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)2795 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
2796                                           OutputStream& out) {
2797     SkASSERT(srcType.isMatrix());
2798     SkASSERT(dstType.isMatrix());
2799     SkASSERT(srcType.componentType().matches(dstType.componentType()));
2800     const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
2801     const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
2802     SkASSERT(dstType.componentType().isFloat());
2803     SpvId dstColumnTypeId = this->getType(dstColumnType);
2804     const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
2805     const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
2806 
2807     STArray<4, SpvId> columns;
2808     for (int i = 0; i < dstType.columns(); i++) {
2809         if (i < srcType.columns()) {
2810             // we're still inside the src matrix, copy the column
2811             SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
2812             SpvId dstColumn;
2813             if (srcType.rows() == dstType.rows()) {
2814                 // columns are equal size, don't need to do anything
2815                 dstColumn = srcColumn;
2816             }
2817             else if (dstType.rows() > srcType.rows()) {
2818                 // dst column is bigger, need to zero-pad it
2819                 STArray<4, SpvId> values;
2820                 values.push_back(srcColumn);
2821                 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
2822                     values.push_back((i == j) ? oneId : zeroId);
2823                 }
2824                 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
2825             }
2826             else {
2827                 // dst column is smaller, need to swizzle the src column
2828                 dstColumn = this->nextId(&dstType);
2829                 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
2830                 this->writeWord(dstColumnTypeId, out);
2831                 this->writeWord(dstColumn, out);
2832                 this->writeWord(srcColumn, out);
2833                 this->writeWord(srcColumn, out);
2834                 for (int j = 0; j < dstType.rows(); j++) {
2835                     this->writeWord(j, out);
2836                 }
2837             }
2838             columns.push_back(dstColumn);
2839         } else {
2840             // we're past the end of the src matrix, need to synthesize an identity-matrix column
2841             STArray<4, SpvId> values;
2842             for (int j = 0; j < dstType.rows(); ++j) {
2843                 values.push_back((i == j) ? oneId : zeroId);
2844             }
2845             columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
2846         }
2847     }
2848 
2849     return this->writeOpCompositeConstruct(dstType, columns, out);
2850 }
2851 
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2852 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2853                                         TArray<SpvId>* currentColumn,
2854                                         TArray<SpvId>* columnIds,
2855                                         int rows,
2856                                         SpvId entry,
2857                                         OutputStream& out) {
2858     SkASSERT(currentColumn->size() < rows);
2859     currentColumn->push_back(entry);
2860     if (currentColumn->size() == rows) {
2861         // Synthesize this column into a vector.
2862         SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2863         columnIds->push_back(columnId);
2864         currentColumn->clear();
2865     }
2866 }
2867 
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2868 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2869     const Type& type = c.type();
2870     SkASSERT(type.isMatrix());
2871     SkASSERT(!c.arguments().empty());
2872     const Type& arg0Type = c.arguments()[0]->type();
2873     // go ahead and write the arguments so we don't try to write new instructions in the middle of
2874     // an instruction
2875     STArray<16, SpvId> arguments;
2876     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2877         arguments.push_back(this->writeExpression(*arg, out));
2878     }
2879 
2880     if (arguments.size() == 1 && arg0Type.isVector()) {
2881         // Special-case handling of float4 -> mat2x2.
2882         SkASSERT(type.rows() == 2 && type.columns() == 2);
2883         SkASSERT(arg0Type.columns() == 4);
2884         SpvId v[4];
2885         for (int i = 0; i < 4; ++i) {
2886             v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2887         }
2888         const Type& vecType = type.columnType(fContext);
2889         SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2890         SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2891         return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2892     }
2893 
2894     int rows = type.rows();
2895     const Type& columnType = type.columnType(fContext);
2896     // SpvIds of completed columns of the matrix.
2897     STArray<4, SpvId> columnIds;
2898     // SpvIds of scalars we have written to the current column so far.
2899     STArray<4, SpvId> currentColumn;
2900     for (int i = 0; i < arguments.size(); i++) {
2901         const Type& argType = c.arguments()[i]->type();
2902         if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2903             // This vector is a complete matrix column by itself and can be used as-is.
2904             columnIds.push_back(arguments[i]);
2905         } else if (argType.columns() == 1) {
2906             // This argument is a lone scalar and can be added to the current column as-is.
2907             this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, arguments[i], out);
2908         } else {
2909             // This argument needs to be decomposed into its constituent scalars.
2910             for (int j = 0; j < argType.columns(); ++j) {
2911                 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2912                                                               arguments[i], j, out);
2913                 this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, swizzle, out);
2914             }
2915         }
2916     }
2917     SkASSERT(columnIds.size() == type.columns());
2918     return this->writeOpCompositeConstruct(type, columnIds, out);
2919 }
2920 
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2921 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2922                                                    OutputStream& out) {
2923     return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2924                                : this->writeVectorConstructor(c, out);
2925 }
2926 
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2927 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2928     const Type& type = c.type();
2929     const Type& componentType = type.componentType();
2930     SkASSERT(type.isVector());
2931 
2932     STArray<4, SpvId> arguments;
2933     for (int i = 0; i < c.arguments().size(); i++) {
2934         const Type& argType = c.arguments()[i]->type();
2935         SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2936 
2937         SpvId arg = this->writeExpression(*c.arguments()[i], out);
2938         if (argType.isMatrix()) {
2939             // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2940             // each scalar separately.
2941             SkASSERT(argType.rows() == 2);
2942             SkASSERT(argType.columns() == 2);
2943             for (int j = 0; j < 4; ++j) {
2944                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2945                                                                   j / 2, j % 2, out));
2946             }
2947         } else if (argType.isVector()) {
2948             // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2949             // vector arguments at all, so we always extract each vector component and pass them
2950             // into OpCompositeConstruct individually.
2951             for (int j = 0; j < argType.columns(); j++) {
2952                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2953             }
2954         } else {
2955             arguments.push_back(arg);
2956         }
2957     }
2958 
2959     return this->writeOpCompositeConstruct(type, arguments, out);
2960 }
2961 
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2962 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2963     // Write the splat argument as a scalar, then splat it.
2964     SpvId argument = this->writeExpression(*c.argument(), out);
2965     return this->splat(c.type(), argument, out);
2966 }
2967 
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2968 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2969     SkASSERT(c.type().isArray() || c.type().isStruct());
2970     auto ctorArgs = c.argumentSpan();
2971 
2972     STArray<4, SpvId> arguments;
2973     for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2974         arguments.push_back(this->writeExpression(*arg, out));
2975     }
2976 
2977     return this->writeOpCompositeConstruct(c.type(), arguments, out);
2978 }
2979 
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2980 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2981                                                      OutputStream& out) {
2982     const Type& type = c.type();
2983     if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2984         return this->writeExpression(*c.argument(), out);
2985     }
2986 
2987     const Expression& ctorExpr = *c.argument();
2988     SpvId expressionId = this->writeExpression(ctorExpr, out);
2989     return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2990 }
2991 
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2992 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2993                                                        OutputStream& out) {
2994     const Type& ctorType = c.type();
2995     const Type& argType = c.argument()->type();
2996     SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2997 
2998     // Write the composite that we are casting. If the actual type matches, we are done.
2999     SpvId compositeId = this->writeExpression(*c.argument(), out);
3000     if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
3001         return compositeId;
3002     }
3003 
3004     // writeMatrixCopy can cast matrices to a different type.
3005     if (ctorType.isMatrix()) {
3006         return this->writeMatrixCopy(compositeId, argType, ctorType, out);
3007     }
3008 
3009     // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
3010     // components and convert each one manually.
3011     const Type& srcType = argType.componentType();
3012     const Type& dstType = ctorType.componentType();
3013 
3014     STArray<4, SpvId> arguments;
3015     for (int index = 0; index < argType.columns(); ++index) {
3016         SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
3017         arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
3018     }
3019 
3020     return this->writeOpCompositeConstruct(ctorType, arguments, out);
3021 }
3022 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)3023 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
3024                                                          OutputStream& out) {
3025     const Type& type = c.type();
3026     SkASSERT(type.isMatrix());
3027     SkASSERT(c.argument()->type().isScalar());
3028 
3029     // Write out the scalar argument.
3030     SpvId diagonal = this->writeExpression(*c.argument(), out);
3031 
3032     // Build the diagonal matrix.
3033     SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3034 
3035     const Type& vecType = type.columnType(fContext);
3036     STArray<4, SpvId> columnIds;
3037     STArray<4, SpvId> arguments;
3038     arguments.resize(type.rows());
3039     for (int column = 0; column < type.columns(); column++) {
3040         for (int row = 0; row < type.rows(); row++) {
3041             arguments[row] = (row == column) ? diagonal : zeroId;
3042         }
3043         columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
3044     }
3045     return this->writeOpCompositeConstruct(type, columnIds, out);
3046 }
3047 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)3048 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
3049                                                        OutputStream& out) {
3050     // Write the input matrix.
3051     SpvId argument = this->writeExpression(*c.argument(), out);
3052 
3053     // Use matrix-copy to resize the input matrix to its new size.
3054     return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
3055 }
3056 
get_storage_class_for_global_variable(const Variable & var,StorageClass fallbackStorageClass)3057 static StorageClass get_storage_class_for_global_variable(
3058         const Variable& var, StorageClass fallbackStorageClass) {
3059     SkASSERT(var.storage() == Variable::Storage::kGlobal);
3060 
3061     if (var.type().typeKind() == Type::TypeKind::kSampler ||
3062         var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
3063         var.type().typeKind() == Type::TypeKind::kTexture) {
3064         return StorageClass::kUniformConstant;
3065     }
3066 
3067     const Layout& layout = var.layout();
3068     ModifierFlags flags = var.modifierFlags();
3069     if (flags & ModifierFlag::kIn) {
3070         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3071         return StorageClass::kInput;
3072     }
3073     if (flags & ModifierFlag::kOut) {
3074         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3075         return StorageClass::kOutput;
3076     }
3077     if (flags.isUniform()) {
3078         if (layout.fFlags & LayoutFlag::kPushConstant) {
3079             return StorageClass::kPushConstant;
3080         }
3081         return StorageClass::kUniform;
3082     }
3083     if (flags.isBuffer()) {
3084         return StorageClass::kStorageBuffer;
3085     }
3086     if (flags.isWorkgroup()) {
3087         return StorageClass::kWorkgroup;
3088     }
3089     return fallbackStorageClass;
3090 }
3091 
getStorageClass(const Expression & expr)3092 StorageClass SPIRVCodeGenerator::getStorageClass(const Expression& expr) {
3093     switch (expr.kind()) {
3094         case Expression::Kind::kVariableReference: {
3095             const Variable& var = *expr.as<VariableReference>().variable();
3096             if (fActiveSpecialization) {
3097                 const Expression** specializedExpr = fActiveSpecialization->find(&var);
3098                 if (specializedExpr && (*specializedExpr)->is<FieldAccess>()) {
3099                     return this->getStorageClass(**specializedExpr);
3100                 }
3101             }
3102             if (var.storage() != Variable::Storage::kGlobal) {
3103                 return StorageClass::kFunction;
3104             }
3105             return get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3106         }
3107         case Expression::Kind::kFieldAccess:
3108             return this->getStorageClass(*expr.as<FieldAccess>().base());
3109         case Expression::Kind::kIndex:
3110             return this->getStorageClass(*expr.as<IndexExpression>().base());
3111         default:
3112             return StorageClass::kFunction;
3113     }
3114 }
3115 
getAccessChain(const Expression & expr,OutputStream & out)3116 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3117     switch (expr.kind()) {
3118         case Expression::Kind::kIndex: {
3119             const IndexExpression& indexExpr = expr.as<IndexExpression>();
3120             if (indexExpr.base()->is<Swizzle>()) {
3121                 // Access chains don't directly support dynamically indexing into a swizzle, but we
3122                 // can rewrite them into a supported form.
3123                 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3124                                             out);
3125             }
3126             // All other index-expressions can be represented as typical access chains.
3127             TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3128             chain.push_back(this->writeExpression(*indexExpr.index(), out));
3129             return chain;
3130         }
3131         case Expression::Kind::kFieldAccess: {
3132             const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3133             TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3134             chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3135             return chain;
3136         }
3137         case Expression::Kind::kVariableReference: {
3138             if (fActiveSpecialization) {
3139                 const Expression** specializedFieldIndex =
3140                         fActiveSpecialization->find(expr.as<VariableReference>().variable());
3141                 if (specializedFieldIndex && (*specializedFieldIndex)->is<FieldAccess>()) {
3142                     return this->getAccessChain(**specializedFieldIndex, out);
3143                 }
3144             }
3145             [[fallthrough]];
3146         }
3147         default: {
3148             SpvId id = this->getLValue(expr, out)->getPointer();
3149             SkASSERT(id != NA);
3150             return TArray<SpvId>{id};
3151         }
3152     }
3153     SkUNREACHABLE;
3154 }
3155 
3156 class PointerLValue : public SPIRVCodeGenerator::LValue {
3157 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,StorageClass storageClass)3158     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3159                   SPIRVCodeGenerator::Precision precision, StorageClass storageClass)
3160     : fGen(gen)
3161     , fPointer(pointer)
3162     , fIsMemoryObject(isMemoryObject)
3163     , fType(type)
3164     , fPrecision(precision)
3165     , fStorageClass(storageClass) {}
3166 
getPointer()3167     SpvId getPointer() override {
3168         return fPointer;
3169     }
3170 
isMemoryObjectPointer() const3171     bool isMemoryObjectPointer() const override {
3172         return fIsMemoryObject;
3173     }
3174 
storageClass() const3175     StorageClass storageClass() const override {
3176         return fStorageClass;
3177     }
3178 
load(OutputStream & out)3179     SpvId load(OutputStream& out) override {
3180         return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3181     }
3182 
store(SpvId value,OutputStream & out)3183     void store(SpvId value, OutputStream& out) override {
3184         if (!fIsMemoryObject) {
3185             // We are going to write into an access chain; this could represent one component of a
3186             // vector, or one element of an array. This has the potential to invalidate other,
3187             // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3188             // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3189             // relying on stale data, reset the store cache entirely when this happens.
3190             fGen.fStoreCache.reset();
3191         }
3192 
3193         fGen.writeOpStore(fStorageClass, fPointer, value, out);
3194     }
3195 
3196 private:
3197     SPIRVCodeGenerator& fGen;
3198     const SpvId fPointer;
3199     const bool fIsMemoryObject;
3200     const SpvId fType;
3201     const SPIRVCodeGenerator::Precision fPrecision;
3202     const StorageClass fStorageClass;
3203 };
3204 
3205 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3206 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,StorageClass storageClass)3207     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3208                   const Type& baseType, const Type& swizzleType, StorageClass storageClass)
3209     : fGen(gen)
3210     , fVecPointer(vecPointer)
3211     , fComponents(components)
3212     , fBaseType(&baseType)
3213     , fSwizzleType(&swizzleType)
3214     , fStorageClass(storageClass) {}
3215 
applySwizzle(const ComponentArray & components,const Type & newType)3216     bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3217         ComponentArray updatedSwizzle;
3218         for (int8_t component : components) {
3219             if (component < 0 || component >= fComponents.size()) {
3220                 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3221                 return false;
3222             }
3223             updatedSwizzle.push_back(fComponents[component]);
3224         }
3225         fComponents = updatedSwizzle;
3226         fSwizzleType = &newType;
3227         return true;
3228     }
3229 
storageClass() const3230     StorageClass storageClass() const override {
3231         return fStorageClass;
3232     }
3233 
load(OutputStream & out)3234     SpvId load(OutputStream& out) override {
3235         SpvId base = fGen.nextId(fBaseType);
3236         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3237         SpvId result = fGen.nextId(fBaseType);
3238         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3239         fGen.writeWord(fGen.getType(*fSwizzleType), out);
3240         fGen.writeWord(result, out);
3241         fGen.writeWord(base, out);
3242         fGen.writeWord(base, out);
3243         for (int component : fComponents) {
3244             fGen.writeWord(component, out);
3245         }
3246         return result;
3247     }
3248 
store(SpvId value,OutputStream & out)3249     void store(SpvId value, OutputStream& out) override {
3250         // use OpVectorShuffle to mix and match the vector components. We effectively create
3251         // a virtual vector out of the concatenation of the left and right vectors, and then
3252         // select components from this virtual vector to make the result vector. For
3253         // instance, given:
3254         // float3L = ...;
3255         // float3R = ...;
3256         // L.xz = R.xy;
3257         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3258         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3259         // (3, 1, 4).
3260         SpvId base = fGen.nextId(fBaseType);
3261         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3262         SpvId shuffle = fGen.nextId(fBaseType);
3263         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3264         fGen.writeWord(fGen.getType(*fBaseType), out);
3265         fGen.writeWord(shuffle, out);
3266         fGen.writeWord(base, out);
3267         fGen.writeWord(value, out);
3268         for (int i = 0; i < fBaseType->columns(); i++) {
3269             // current offset into the virtual vector, defaults to pulling the unmodified
3270             // value from the left side
3271             int offset = i;
3272             // check to see if we are writing this component
3273             for (int j = 0; j < fComponents.size(); j++) {
3274                 if (fComponents[j] == i) {
3275                     // we're writing to this component, so adjust the offset to pull from
3276                     // the correct component of the right side instead of preserving the
3277                     // value from the left
3278                     offset = (int) (j + fBaseType->columns());
3279                     break;
3280                 }
3281             }
3282             fGen.writeWord(offset, out);
3283         }
3284         fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3285     }
3286 
3287 private:
3288     SPIRVCodeGenerator& fGen;
3289     const SpvId fVecPointer;
3290     ComponentArray fComponents;
3291     const Type* fBaseType;
3292     const Type* fSwizzleType;
3293     const StorageClass fStorageClass;
3294 };
3295 
findUniformFieldIndex(const Variable & var) const3296 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3297     int* fieldIndex = fTopLevelUniformMap.find(&var);
3298     return fieldIndex ? *fieldIndex : -1;
3299 }
3300 
// Builds an LValue object through which `expr` can be loaded and stored, writing any addressing
// instructions that are required along the way. The handling depends on the expression kind:
//   - variable reference: variables with an index in the top-level uniform map are addressed via
//     an OpAccessChain into fUniformBufferId; sk_SampleMask / sk_SampleMaskIn are addressed via
//     an OpAccessChain to element 0; all other variables use the ID recorded in fVariableMap.
//   - index / field access: a single OpAccessChain built from getAccessChain().
//   - swizzle: folded into the base lvalue when that lvalue supports applySwizzle(); otherwise a
//     one-component swizzle becomes an OpAccessChain and a wider one becomes a SwizzleLValue.
//   - anything else: the expression's value is stored into a fresh Function-storage variable.
getLValue(const Expression & expr,OutputStream & out)3301 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
3302                                                                           OutputStream& out) {
3303     const Type& type = expr.type();
         // Types that are not high-precision are handled with relaxed precision.
3304     Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
3305     switch (expr.kind()) {
3306         case Expression::Kind::kVariableReference: {
3307             const Variable& var = *expr.as<VariableReference>().variable();
                 // A non-negative index means `var` has an entry in fTopLevelUniformMap.
3308             int uniformIdx = this->findUniformFieldIndex(var);
3309             if (uniformIdx >= 0) {
3310                 // Access uniforms via an AccessChain into the uniform-buffer struct.
3311                 SpvId memberId = this->nextId(nullptr);
3312                 SpvId typeId = this->getPointerType(type, StorageClass::kUniform);
3313                 SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
3314                 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
3315                                        uniformIdxId, out);
3316                 return std::make_unique<PointerLValue>(
3317                         *this,
3318                         memberId,
3319                         /*isMemoryObjectPointer=*/true,
3320                         this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
3321                         precision,
3322                         StorageClass::kUniform);
3323             }
3324
                 // Every other variable must already have been assigned an ID in fVariableMap.
3325             SpvId* entry = fVariableMap.find(&var);
3326             SkASSERTF(entry, "%s", expr.description().c_str());
3327
3328             if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
3329                 var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3330                 // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
3331                 // represents sample masks as an array of uints.
3332                 StorageClass storageClass =
3333                         get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3334                 SkASSERT(storageClass != StorageClass::kPrivate);
3335                 SkASSERT(type.matches(*fContext.fTypes.fUInt));
3336
                     // Only element 0 of the sample-mask array is addressed.
3337                 SpvId accessId = this->nextId(nullptr);
3338                 SpvId typeId = this->getPointerType(type, storageClass);
3339                 SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
3340                 this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
3341                 return std::make_unique<PointerLValue>(*this,
3342                                                        accessId,
3343                                                        /*isMemoryObjectPointer=*/true,
3344                                                        this->getType(type),
3345                                                        precision,
3346                                                        storageClass);
3347             }
3348
                 // Ordinary variable: the ID from fVariableMap is used as the pointer directly.
3349             SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
3350             return std::make_unique<PointerLValue>(*this, *entry,
3351                                                    /*isMemoryObjectPointer=*/true,
3352                                                    typeId, precision, this->getStorageClass(expr));
3353         }
3354         case Expression::Kind::kIndex: // fall through
3355         case Expression::Kind::kFieldAccess: {
                 // `chain` holds the base pointer followed by one ID per index / field step.
3356             TArray<SpvId> chain = this->getAccessChain(expr, out);
3357             SpvId member = this->nextId(nullptr);
3358             StorageClass storageClass = this->getStorageClass(expr);
                 // Word count: opcode word + result type + result ID + every entry of `chain`.
3359             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
3360             this->writeWord(this->getPointerType(type, storageClass), out);
3361             this->writeWord(member, out);
3362             for (SpvId idx : chain) {
3363                 this->writeWord(idx, out);
3364             }
                 // The resulting pointer is an access chain, so it is flagged as not being a
                 // memory-object pointer.
3365             return std::make_unique<PointerLValue>(
3366                     *this,
3367                     member,
3368                     /*isMemoryObjectPointer=*/false,
3369                     this->getType(type,
3370                                   kDefaultTypeLayout,
3371                                   this->memoryLayoutForStorageClass(storageClass)),
3372                     precision,
3373                     storageClass);
3374         }
3375         case Expression::Kind::kSwizzle: {
3376             const Swizzle& swizzle = expr.as<Swizzle>();
3377             std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
                 // If the base lvalue can absorb the swizzle itself, reuse it as-is.
3378             if (lvalue->applySwizzle(swizzle.components(), type)) {
3379                 return lvalue;
3380             }
3381             SpvId base = lvalue->getPointer();
                 // NOTE(review): after the error is reported, execution continues below with
                 // `base == NA`; there is no early return on this path.
3382             if (base == NA) {
3383                 fContext.fErrors->error(swizzle.fPosition,
3384                         "unable to retrieve lvalue from swizzle");
3385             }
3386             StorageClass storageClass = this->getStorageClass(*swizzle.base());
3387             if (swizzle.components().size() == 1) {
                     // A single component can be addressed directly with an access chain whose
                     // index is the component number.
3388                 SpvId member = this->nextId(nullptr);
3389                 SpvId typeId = this->getPointerType(type, storageClass);
3390                 SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
3391                 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
3392                 return std::make_unique<PointerLValue>(*this, member,
3393                                                        /*isMemoryObjectPointer=*/false,
3394                                                        this->getType(type),
3395                                                        precision, storageClass);
3396             } else {
                     // Multiple components: defer to SwizzleLValue, which loads and stores the
                     // entire base vector.
3397                 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
3398                                                        swizzle.base()->type(), type, storageClass);
3399             }
3400         }
3401         default: {
3402             // expr isn't actually an lvalue, create a placeholder variable for it. This case
3403             // happens due to the need to store values in temporary variables during function
3404             // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
3405             // lvalues should have been caught before code generation.
3406             //
3407             // This is with the exception of opaque handle types (textures/samplers) which are
3408             // always defined as UniformConstant pointers and don't need to be explicitly stored
3409             // into a temporary (which is handled explicitly in writeFunctionCallArgument).
3410             SpvId result = this->nextId(nullptr);
3411             SpvId pointerType = this->getPointerType(type, StorageClass::kFunction);
                 // Note that the OpVariable goes to fVariableBuffer rather than to `out`; the
                 // store of the expression's value is written to `out`.
3412             this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
3413                                    fVariableBuffer);
3414             this->writeOpStore(StorageClass::kFunction, result, this->writeExpression(expr, out),
3415                                out);
3416             return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
3417                                                    this->getType(type), precision,
3418                                                    StorageClass::kFunction);
3419         }
3420     }
3421 }
3422
3422 
identifier(std::string_view name)3423 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3424     std::unique_ptr<Expression> expr =
3425             fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3426     return expr ? std::move(expr)
3427                 : Poison::Make(Position(), fContext);
3428 }
3429 
// Writes the code needed to read the variable referenced by `ref` and returns the ID of the
// resulting value. Several builtins, identified by `layout().fBuiltin`, are special-cased:
//   - DEVICE_FRAGCOORDS_BUILTIN / DEVICE_CLOCKWISE_BUILTIN: load sk_FragCoord / sk_Clockwise
//     directly, with no RTFlip adjustment.
//   - SK_SECONDARYFRAGCOLOR_BUILTIN: loaded only when fCaps.fDualSourceBlendingSupport is set;
//     otherwise an error is reported and NA is returned.
//   - SK_FRAGCOORD_BUILTIN / SK_CLOCKWISE_BUILTIN: adjusted using the RTFlip uniform unless
//     fForceNoRTFlip is set.
// Every other variable is constant-folded when possible, turned into an OpSampledImage when it is
// a sampler with synthesized texture/sampler backing, or otherwise loaded through getLValue().
writeVariableReference(const VariableReference & ref,OutputStream & out)3430 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
3431     const Variable* variable = ref.variable();
3432     switch (variable->layout().fBuiltin) {
3433         case DEVICE_FRAGCOORDS_BUILTIN: {
3434             // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
3435             // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
3436             // access the fragcoord; do so now.
3437             return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3438         }
3439         case DEVICE_CLOCKWISE_BUILTIN: {
3440             // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
3441             // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
3442             // access front facing; do so now.
3443             return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3444         }
3445         case SK_SECONDARYFRAGCOLOR_BUILTIN: {
                 // The secondary fragment color is only usable when dual-source blending is
                 // supported; otherwise report an error and return NA.
3446             if (fCaps.fDualSourceBlendingSupport) {
3447                 return this->getLValue(*this->identifier("sk_SecondaryFragColor"), out)->load(out);
3448             } else {
3449                 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
3450                 return NA;
3451             }
3452         }
3453         case SK_FRAGCOORD_BUILTIN: {
                 // With fForceNoRTFlip set, sk_FragCoord is loaded without any adjustment.
3454             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3455                 return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3456             }
3457
3458             // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
3459             this->addRTFlipUniform(ref.fPosition);
3460             // Compute the flipped coordinate from the RTFlip uniform (SKSL_RTFLIP_NAME).
3461             // Use a uniform to flip the Y coordinate. The new expression will be written in
3462             // terms of $device_FragCoords, which is a fake variable that means "access the
3463             // underlying fragcoords directly without flipping it".
3464             static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
                 // Create the fake float4 variable the first time it is needed; later references
                 // find it in the symbol table.
3465             if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
3466                 AutoAttachPoolToThread attach(fProgram.fPool.get());
3467                 Layout layout;
3468                 layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
3469                 auto coordsVar = Variable::Make(/*pos=*/Position(),
3470                                                 /*modifiersPosition=*/Position(),
3471                                                 layout,
3472                                                 ModifierFlag::kNone,
3473                                                 fContext.fTypes.fFloat4.get(),
3474                                                 DEVICE_COORDS_NAME,
3475                                                 /*mangledName=*/"",
3476                                                 /*builtin=*/true,
3477                                                 Variable::Storage::kGlobal);
3478                 fProgram.fSymbols->add(fContext, std::move(coordsVar));
3479             }
3480             std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
3481             std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
                 // Each writeSwizzle call below writes its base expression again before
                 // extracting the requested components.
3482             SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
3483             SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3484             SpvId deviceCoordX  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
3485             SpvId deviceCoordY  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
3486             SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
3487                                                                     SwizzleComponent::W}, out);
3488             // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
3489             SpvId flippedY = this->writeBinaryExpression(
3490                                      *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
3491                                      *fContext.fTypes.fFloat, deviceCoordY,
3492                                      *fContext.fTypes.fFloat, out);
3493
3494             // Compute `flippedY = u_RTFlip.x + flippedY`.
3495             flippedY = this->writeBinaryExpression(
3496                                *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
3497                                *fContext.fTypes.fFloat, flippedY,
3498                                *fContext.fTypes.fFloat, out);
3499
3500             // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
3501             return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
3502                                                    {deviceCoordX, flippedY, deviceCoordZW},
3503                                                    out);
3504         }
3505         case SK_CLOCKWISE_BUILTIN: {
                 // With fForceNoRTFlip set, sk_Clockwise is loaded without any adjustment.
3506             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3507                 return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3508             }
3509
3510             // Apply RTFlip to sk_Clockwise.
3511             this->addRTFlipUniform(ref.fPosition);
3512             // Use a uniform to flip the Y coordinate. The new expression will be written in
3513             // terms of $device_Clockwise, which is a fake variable that means "access the
3514             // underlying FrontFacing directly".
3515             static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
                 // Create the fake bool variable the first time it is needed; later references
                 // find it in the symbol table.
3516             if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
3517                 AutoAttachPoolToThread attach(fProgram.fPool.get());
3518                 Layout layout;
3519                 layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
3520                 auto clockwiseVar = Variable::Make(/*pos=*/Position(),
3521                                                    /*modifiersPosition=*/Position(),
3522                                                    layout,
3523                                                    ModifierFlag::kNone,
3524                                                    fContext.fTypes.fBool.get(),
3525                                                    DEVICE_CLOCKWISE_NAME,
3526                                                    /*mangledName=*/"",
3527                                                    /*builtin=*/true,
3528                                                    Variable::Storage::kGlobal);
3529                 fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
3530             }
3531             // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
3532             // we use the default convention of "counter-clockwise face is front".
3533
3534             // Compute `positiveRTFlip = (rtFlip.y > 0)`.
3535             std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
3536             SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3537             SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3538             SpvId positiveRTFlip = this->writeBinaryExpression(
3539                                            *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
3540                                            *fContext.fTypes.fFloat, zero,
3541                                            *fContext.fTypes.fBool, out);
3542
3543             // Compute `positiveRTFlip ^^ $device_Clockwise`.
3544             std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
3545             SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
3546             return this->writeBinaryExpression(
3547                            *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
3548                            *fContext.fTypes.fBool, deviceClockwiseID,
3549                            *fContext.fTypes.fBool, out);
3550         }
3551         default: {
3552             // Constant-propagate variables that have a known compile-time value.
3553             if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
3554                 return this->writeExpression(*expr, out);
3555             }
3556
3557             // A reference to a sampler variable at global scope with synthesized texture/sampler
3558             // backing should construct a function-scope combined image-sampler from the synthesized
3559             // constituents. This is the case in which a sample intrinsic was invoked.
3560             //
3561             // Variable references to opaque handles (texture/sampler) that appear as the argument
3562             // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
3563             if (fUseTextureSamplerPairs && variable->type().isSampler()) {
3564                 if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
3565                     SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
3566                     SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
3567                     SkASSERT(imgPtr);
3568                     SkASSERT(samplerPtr);
3569
                         // Load the texture and the sampler separately, then combine them with
                         // OpSampledImage.
3570                     SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
3571                                                   Precision::kDefault, *imgPtr, out);
3572                     SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
3573                                                       Precision::kDefault,
3574                                                       *samplerPtr,
3575                                                       out);
3576                     SpvId result = this->nextId(nullptr);
3577                     this->writeInstruction(SpvOpSampledImage,
3578                                            this->getType(variable->type()),
3579                                            result,
3580                                            img,
3581                                            sampler,
3582                                            out);
3583                     return result;
3584                 }
                     // NOTE(review): there is no return after this failure; if SkDEBUGFAIL does
                     // not abort (presumably it is debug-only, to be confirmed), control continues
                     // to the ordinary load below.
3585                 SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
3586             }
3587             return this->getLValue(ref, out)->load(out);
3588         }
3589     }
3590 }
3591
3591 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3592 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3593     if (expr.base()->type().isVector()) {
3594         SpvId base = this->writeExpression(*expr.base(), out);
3595         SpvId index = this->writeExpression(*expr.index(), out);
3596         SpvId result = this->nextId(nullptr);
3597         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3598                                index, out);
3599         return result;
3600     }
3601     return getLValue(expr, out)->load(out);
3602 }
3603 
writeFieldAccess(const FieldAccess & f,OutputStream & out)3604 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3605     return getLValue(f, out)->load(out);
3606 }
3607 
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3608 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3609                                        const ComponentArray& components,
3610                                        OutputStream& out) {
3611     size_t count = components.size();
3612     const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3613     SpvId base = this->writeExpression(baseExpr, out);
3614     if (count == 1) {
3615         return this->writeOpCompositeExtract(type, base, components[0], out);
3616     }
3617 
3618     SpvId result = this->nextId(&type);
3619     this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3620     this->writeWord(this->getType(type), out);
3621     this->writeWord(result, out);
3622     this->writeWord(base, out);
3623     this->writeWord(base, out);
3624     for (int component : components) {
3625         this->writeWord(component, out);
3626     }
3627     return result;
3628 }
3629 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3630 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3631     return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3632 }
3633 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3634 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3635                                                SpvId lhs, SpvId rhs,
3636                                                bool writeComponentwiseIfMatrix,
3637                                                SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3638                                                SpvOp_ ifBool, OutputStream& out) {
3639     SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3640     if (op == SpvOpUndef) {
3641         fContext.fErrors->error(operandType.fPosition,
3642                 "unsupported operand for binary expression: " + operandType.description());
3643         return NA;
3644     }
3645     if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3646         return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3647     }
3648     SpvId result = this->nextId(&resultType);
3649     this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3650     return result;
3651 }
3652 
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3653 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3654                                                                     const Type& operandType,
3655                                                                     SpvId lhs, SpvId rhs,
3656                                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
3657                                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
3658                                                                     OutputStream& out) {
3659     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3660                                       /*writeComponentwiseIfMatrix=*/true,
3661                                       ifFloat, ifInt, ifUInt, ifBool, out);
3662 }
3663 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3664 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3665                                                SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3666                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3667     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3668                                       /*writeComponentwiseIfMatrix=*/false,
3669                                       ifFloat, ifInt, ifUInt, ifBool, out);
3670 }
3671 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3672 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3673                                      OutputStream& out) {
3674     if (operandType.isVector()) {
3675         SpvId result = this->nextId(nullptr);
3676         this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3677         return result;
3678     }
3679     return id;
3680 }
3681 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3682 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3683                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
3684                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3685                                                 OutputStream& out) {
3686     SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3687     SkASSERT(operandType.isMatrix());
3688     const Type& columnType = operandType.componentType().toCompound(fContext,
3689                                                                     operandType.rows(),
3690                                                                     1);
3691     SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3692                                                                      operandType.rows(),
3693                                                                      1));
3694     SpvId boolType = this->getType(*fContext.fTypes.fBool);
3695     SpvId result = 0;
3696     for (int i = 0; i < operandType.columns(); i++) {
3697         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3698         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3699         SpvId compare = this->nextId(&operandType);
3700         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3701         SpvId merge = this->nextId(nullptr);
3702         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3703         if (result != 0) {
3704             SpvId next = this->nextId(nullptr);
3705             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3706             result = next;
3707         } else {
3708             result = merge;
3709         }
3710     }
3711     return result;
3712 }
3713 
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3714 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3715                                                         SpvId operand,
3716                                                         SpvOp_ op,
3717                                                         OutputStream& out) {
3718     SkASSERT(operandType.isMatrix());
3719     const Type& columnType = operandType.columnType(fContext);
3720     SpvId columnTypeId = this->getType(columnType);
3721 
3722     STArray<4, SpvId> columns;
3723     for (int i = 0; i < operandType.columns(); i++) {
3724         SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3725         SpvId dstColumn = this->nextId(&operandType);
3726         this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3727         columns.push_back(dstColumn);
3728     }
3729 
3730     return this->writeOpCompositeConstruct(operandType, columns, out);
3731 }
3732 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3733 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3734                                                          SpvId rhs, SpvOp_ op, OutputStream& out) {
3735     SkASSERT(operandType.isMatrix());
3736     const Type& columnType = operandType.columnType(fContext);
3737     SpvId columnTypeId = this->getType(columnType);
3738 
3739     STArray<4, SpvId> columns;
3740     for (int i = 0; i < operandType.columns(); i++) {
3741         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3742         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3743         columns.push_back(this->nextId(&operandType));
3744         this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3745     }
3746     return this->writeOpCompositeConstruct(operandType, columns, out);
3747 }
3748 
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3749 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3750     SkASSERT(type.isFloat());
3751     SpvId one = this->writeLiteral(1.0, type);
3752     SpvId reciprocal = this->nextId(&type);
3753     this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3754     return reciprocal;
3755 }
3756 
splat(const Type & type,SpvId id,OutputStream & out)3757 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3758     if (type.isScalar()) {
3759         // Scalars require no additional work; we can return the passed-in ID as is.
3760     } else {
3761         SkASSERT(type.isVector() || type.isMatrix());
3762         bool isMatrix = type.isMatrix();
3763 
3764         // Splat the input scalar across a vector.
3765         int vectorSize = (isMatrix ? type.rows() : type.columns());
3766         const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3767 
3768         STArray<4, SpvId> values;
3769         values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3770         id = this->writeOpCompositeConstruct(vectorType, values, out);
3771 
3772         if (isMatrix) {
3773             // Splat the newly-synthesized vector into a matrix.
3774             STArray<4, SpvId> matArguments;
3775             matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3776             id = this->writeOpCompositeConstruct(type, matArguments, out);
3777         }
3778     }
3779 
3780     return id;
3781 }
3782 
types_match(const Type & a,const Type & b)3783 static bool types_match(const Type& a, const Type& b) {
3784     if (a.matches(b)) {
3785         return true;
3786     }
3787     return (a.typeKind() == b.typeKind()) &&
3788            (a.isScalar() || a.isVector() || a.isMatrix()) &&
3789            (a.columns() == b.columns() && a.rows() == b.rows()) &&
3790            a.componentType().numberKind() == b.componentType().numberKind();
3791 }
3792 
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3793 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3794                                                               SpvId lhs,
3795                                                               const Type& rightType,
3796                                                               SpvId rhs,
3797                                                               const Type& resultType,
3798                                                               OutputStream& out) {
3799     SpvId sum = NA;
3800     const Type& columnType = leftType.columnType(fContext);
3801     const Type& scalarType = rightType.componentType();
3802 
3803     for (int n = 0; n < leftType.rows(); ++n) {
3804         // Extract mat[N] from the matrix.
3805         SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3806 
3807         // Extract vec[N] from the vector.
3808         SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3809 
3810         // Multiply them together.
3811         SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3812                                                     scalarType, vecN,
3813                                                     columnType, out);
3814 
3815         // Sum all the components together.
3816         if (sum == NA) {
3817             sum = product;
3818         } else {
3819             sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3820                                               columnType, product,
3821                                               columnType, out);
3822         }
3823     }
3824 
3825     return sum;
3826 }
3827 
// Emits SPIR-V for `lhs op rhs`, where both operands have already been evaluated to ids.
//
// leftType / rightType: the SkSL types of the two operands.
// op:                   the operator to apply; callers in this file strip any assignment
//                       component before calling (see the BinaryExpression overload).
// resultType:           the SkSL type of the whole expression.
//
// The function works in two phases:
//   1. If the operand types do not match, the mismatch is resolved. Some combinations are
//      emitted immediately with a dedicated instruction (vector*scalar, matrix*vector, ...);
//      otherwise the scalar side is widened to the other operand's shape and `operandType` is
//      set to that shape.
//   2. A switch on the operator picks the opcode(s) for `operandType` and emits the instruction.
//
// Returns the id of the result, or NA after reporting an error for an unsupported combination.
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3828 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
3829                                                 const Type& rightType, SpvId rhs,
3830                                                 const Type& resultType, OutputStream& out) {
3831     // The comma operator ignores the type of the left-hand side entirely.
3832     if (op.kind() == Operator::Kind::COMMA) {
3833         return rhs;
3834     }
3835     // overall type we are operating on: float2, int, uint4...
3836     const Type* operandType;
3837     if (types_match(leftType, rightType)) {
3838         operandType = &leftType;
3839     } else {
3840         // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
3841         // handling in SPIR-V
3842         if (leftType.isVector() && rightType.isNumber()) {
            // vector-op-scalar. Float multiply and divide use OpVectorTimesScalar; division is
            // first turned into multiplication by the reciprocal of the scalar.
3843             if (resultType.componentType().isFloat()) {
3844                 switch (op.kind()) {
3845                     case Operator::Kind::SLASH: {
3846                         rhs = this->writeReciprocal(rightType, rhs, out);
3847                         [[fallthrough]];
3848                     }
3849                     case Operator::Kind::STAR: {
3850                         SpvId result = this->nextId(&resultType);
3851                         this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
3852                                                result, lhs, rhs, out);
3853                         return result;
3854                     }
3855                     default:
3856                         break;
3857                 }
3858             }
3859             // Vectorize the right-hand side.
3860             STArray<4, SpvId> arguments;
3861             arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
3862             rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
3863             operandType = &leftType;
3864         } else if (rightType.isVector() && leftType.isNumber()) {
            // scalar-op-vector. Only float multiplication has a dedicated instruction here;
            // the operands are swapped because OpVectorTimesScalar takes the vector first.
3865             if (resultType.componentType().isFloat()) {
3866                 if (op.kind() == Operator::Kind::STAR) {
3867                     SpvId result = this->nextId(&resultType);
3868                     this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
3869                                            result, rhs, lhs, out);
3870                     return result;
3871                 }
3872             }
3873             // Vectorize the left-hand side.
3874             STArray<4, SpvId> arguments;
3875             arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
3876             lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
3877             operandType = &rightType;
3878         } else if (leftType.isMatrix()) {
3879             if (op.kind() == Operator::Kind::STAR) {
3880                 // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
3881                 // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
3882                 if (fCaps.fRewriteMatrixVectorMultiply &&
3883                     rightType.isVector() &&
3884                     !resultType.highPrecision()) {
3885                     return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
3886                                                                      resultType, out);
3887                 }
3888 
3889                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
3890                 SpvOp_ spvop;
3891                 if (rightType.isMatrix()) {
3892                     spvop = SpvOpMatrixTimesMatrix;
3893                 } else if (rightType.isVector()) {
3894                     spvop = SpvOpMatrixTimesVector;
3895                 } else {
3896                     SkASSERT(rightType.isScalar());
3897                     spvop = SpvOpMatrixTimesScalar;
3898                 }
3899                 SpvId result = this->nextId(&resultType);
3900                 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
3901                 return result;
3902             } else {
3903                 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
3904                 // expect to have a scalar here.
3905                 SkASSERT(rightType.isScalar());
3906 
3907                 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
3908                 SpvId rhsMatrix = this->splat(leftType, rhs, out);
3909 
3910                 // Perform this operation as matrix-op-matrix.
3911                 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
3912                                                    resultType, out);
3913             }
3914         } else if (rightType.isMatrix()) {
3915             if (op.kind() == Operator::Kind::STAR) {
3916                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
3917                 SpvId result = this->nextId(&resultType);
3918                 if (leftType.isVector()) {
3919                     this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
3920                                            result, lhs, rhs, out);
3921                 } else {
3922                     SkASSERT(leftType.isScalar());
                    // The operands are swapped: OpMatrixTimesScalar takes the matrix first.
3923                     this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
3924                                            result, rhs, lhs, out);
3925                 }
3926                 return result;
3927             } else {
3928                 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
3929                 // expect to have a scalar here.
3930                 SkASSERT(leftType.isScalar());
3931 
3932                 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
3933                 SpvId lhsMatrix = this->splat(rightType, lhs, out);
3934 
3935                 // Perform this operation as matrix-op-matrix.
3936                 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
3937                                                    resultType, out);
3938             }
3939         } else {
3940             fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
3941             return NA;
3942         }
3943     }
3944 
    // From here on, both operands have the shape described by `operandType`. In the
    // writeBinaryOperation calls below, the four opcodes are the choices for float, signed-int,
    // unsigned-int and bool operands respectively (per the parameter names ifFloat/ifInt/
    // ifUInt/ifBool); SpvOpUndef marks an operand kind the operator does not support.
3945     switch (op.kind()) {
3946         case Operator::Kind::EQEQ: {
3947             if (operandType->isMatrix()) {
3948                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
3949                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
3950             }
3951             if (operandType->isStruct()) {
3952                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
3953             }
3954             if (operandType->isArray()) {
3955                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
3956             }
3957             SkASSERT(resultType.isBoolean());
            // Vector operands compare componentwise into a bool vector (`tmpType`), which
            // foldToBool then collapses to a single bool with SpvOpAll.
3958             const Type* tmpType;
3959             if (operandType->isVector()) {
3960                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
3961                                                              operandType->columns(),
3962                                                              operandType->rows());
3963             } else {
3964                 tmpType = &resultType;
3965             }
3966             if (lhs == rhs) {
3967                 // This ignores the effects of NaN.
3968                 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
3969             }
3970             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
3971                                                                SpvOpFOrdEqual, SpvOpIEqual,
3972                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
3973                                     *operandType, SpvOpAll, out);
3974         }
3975         case Operator::Kind::NEQ:
3976             if (operandType->isMatrix()) {
3977                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
3978                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
3979             }
3980             if (operandType->isStruct()) {
3981                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
3982             }
3983             if (operandType->isArray()) {
3984                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
3985             }
            // Scalar and vector `!=` share the code below with logical XOR.
3986             [[fallthrough]];
3987         case Operator::Kind::LOGICALXOR:
3988             SkASSERT(resultType.isBoolean());
3989             const Type* tmpType;
3990             if (operandType->isVector()) {
3991                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
3992                                                              operandType->columns(),
3993                                                              operandType->rows());
3994             } else {
3995                 tmpType = &resultType;
3996             }
3997             if (lhs == rhs) {
3998                 // This ignores the effects of NaN.
3999                 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
4000             }
4001             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
4002                                                                SpvOpFUnordNotEqual, SpvOpINotEqual,
4003                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
4004                                                                out),
4005                                     *operandType, SpvOpAny, out);
4006         case Operator::Kind::GT:
4007             SkASSERT(resultType.isBoolean());
4008             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4009                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
4010                                               SpvOpUGreaterThan, SpvOpUndef, out);
4011         case Operator::Kind::LT:
4012             SkASSERT(resultType.isBoolean());
4013             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
4014                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
4015         case Operator::Kind::GTEQ:
4016             SkASSERT(resultType.isBoolean());
4017             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4018                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
4019                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
4020         case Operator::Kind::LTEQ:
4021             SkASSERT(resultType.isBoolean());
4022             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4023                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
4024                                               SpvOpULessThanEqual, SpvOpUndef, out);
        // Addition, subtraction and division are applied column by column when the operands
        // are matrices.
4025         case Operator::Kind::PLUS:
4026             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4027                                                                    lhs, rhs, SpvOpFAdd, SpvOpIAdd,
4028                                                                    SpvOpIAdd, SpvOpUndef, out);
4029         case Operator::Kind::MINUS:
4030             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4031                                                                    lhs, rhs, SpvOpFSub, SpvOpISub,
4032                                                                    SpvOpISub, SpvOpUndef, out);
4033         case Operator::Kind::STAR:
4034             if (leftType.isMatrix() && rightType.isMatrix()) {
4035                 // matrix multiply
4036                 SpvId result = this->nextId(&resultType);
4037                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
4038                                        lhs, rhs, out);
4039                 return result;
4040             }
4041             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
4042                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
4043         case Operator::Kind::SLASH:
4044             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4045                                                                    lhs, rhs, SpvOpFDiv, SpvOpSDiv,
4046                                                                    SpvOpUDiv, SpvOpUndef, out);
4047         case Operator::Kind::PERCENT:
4048             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
4049                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
4050         case Operator::Kind::SHL:
4051             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4052                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
4053                                               SpvOpUndef, out);
4054         case Operator::Kind::SHR:
4055             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4056                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
4057                                               SpvOpUndef, out);
4058         case Operator::Kind::BITWISEAND:
4059             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4060                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
4061         case Operator::Kind::BITWISEOR:
4062             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4063                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
4064         case Operator::Kind::BITWISEXOR:
4065             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4066                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
4067         default:
4068             fContext.fErrors->error(Position(), "unsupported token");
4069             return NA;
4070     }
4071 }
4072 
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4073 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
4074                                                SpvId rhs, OutputStream& out) {
4075     // The inputs must be arrays, and the op must be == or !=.
4076     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4077     SkASSERT(arrayType.isArray());
4078     const Type& componentType = arrayType.componentType();
4079     const int arraySize = arrayType.columns();
4080     SkASSERT(arraySize > 0);
4081 
4082     // Synthesize equality checks for each item in the array.
4083     const Type& boolType = *fContext.fTypes.fBool;
4084     SpvId allComparisons = NA;
4085     for (int index = 0; index < arraySize; ++index) {
4086         // Get the left and right item in the array.
4087         SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
4088         SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
4089         // Use `writeBinaryExpression` with the requested == or != operator on these items.
4090         SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
4091                                                        componentType, itemR, boolType, out);
4092         // Merge this comparison result with all the other comparisons we've done.
4093         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4094     }
4095     return allComparisons;
4096 }
4097 
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4098 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4099                                                 SpvId rhs, OutputStream& out) {
4100     // The inputs must be structs containing fields, and the op must be == or !=.
4101     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4102     SkASSERT(structType.isStruct());
4103     SkSpan<const Field> fields = structType.fields();
4104     SkASSERT(!fields.empty());
4105 
4106     // Synthesize equality checks for each field in the struct.
4107     const Type& boolType = *fContext.fTypes.fBool;
4108     SpvId allComparisons = NA;
4109     for (int index = 0; index < (int)fields.size(); ++index) {
4110         // Get the left and right versions of this field.
4111         const Type& fieldType = *fields[index].fType;
4112 
4113         SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4114         SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4115         // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4116         SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4117                                                        boolType, out);
4118         // Merge this comparison result with all the other comparisons we've done.
4119         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4120     }
4121     return allComparisons;
4122 }
4123 
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4124 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4125                                            OutputStream& out) {
4126     // If this is the first entry, we don't need to merge comparison results with anything.
4127     if (allComparisons == NA) {
4128         return comparison;
4129     }
4130     // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4131     const Type& boolType = *fContext.fTypes.fBool;
4132     SpvId boolTypeId = this->getType(boolType);
4133     SpvId logicalOp = this->nextId(&boolType);
4134     switch (op.kind()) {
4135         case Operator::Kind::EQEQ:
4136             this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4137                                    comparison, allComparisons, out);
4138             break;
4139         case Operator::Kind::NEQ:
4140             this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4141                                    comparison, allComparisons, out);
4142             break;
4143         default:
4144             SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4145             return NA;
4146     }
4147     return logicalOp;
4148 }
4149 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4150 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4151     const Expression* left = b.left().get();
4152     const Expression* right = b.right().get();
4153     Operator op = b.getOperator();
4154 
4155     switch (op.kind()) {
4156         case Operator::Kind::EQ: {
4157             // Handles assignment.
4158             SpvId rhs = this->writeExpression(*right, out);
4159             this->getLValue(*left, out)->store(rhs, out);
4160             return rhs;
4161         }
4162         case Operator::Kind::LOGICALAND:
4163             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4164             return this->writeLogicalAnd(*b.left(), *b.right(), out);
4165 
4166         case Operator::Kind::LOGICALOR:
4167             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4168             return this->writeLogicalOr(*b.left(), *b.right(), out);
4169 
4170         default:
4171             break;
4172     }
4173 
4174     std::unique_ptr<LValue> lvalue;
4175     SpvId lhs;
4176     if (op.isAssignment()) {
4177         lvalue = this->getLValue(*left, out);
4178         lhs = lvalue->load(out);
4179     } else {
4180         lvalue = nullptr;
4181         lhs = this->writeExpression(*left, out);
4182     }
4183 
4184     SpvId rhs = this->writeExpression(*right, out);
4185     SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4186                                                right->type(), rhs, b.type(), out);
4187     if (lvalue) {
4188         lvalue->store(result, out);
4189     }
4190     return result;
4191 }
4192 
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)4193 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
4194                                           OutputStream& out) {
4195     SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
4196     SpvId lhs = this->writeExpression(left, out);
4197 
4198     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4199 
4200     SpvId rhsLabel = this->nextId(nullptr);
4201     SpvId end = this->nextId(nullptr);
4202     SpvId lhsBlock = fCurrentBlock;
4203     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4204     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
4205     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4206     SpvId rhs = this->writeExpression(right, out);
4207     SpvId rhsBlock = fCurrentBlock;
4208     this->writeInstruction(SpvOpBranch, end, out);
4209     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4210     SpvId result = this->nextId(nullptr);
4211     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
4212                            lhsBlock, rhs, rhsBlock, out);
4213 
4214     return result;
4215 }
4216 
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)4217 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
4218                                          OutputStream& out) {
4219     SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
4220     SpvId lhs = this->writeExpression(left, out);
4221 
4222     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4223 
4224     SpvId rhsLabel = this->nextId(nullptr);
4225     SpvId end = this->nextId(nullptr);
4226     SpvId lhsBlock = fCurrentBlock;
4227     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4228     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
4229     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4230     SpvId rhs = this->writeExpression(right, out);
4231     SpvId rhsBlock = fCurrentBlock;
4232     this->writeInstruction(SpvOpBranch, end, out);
4233     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4234     SpvId result = this->nextId(nullptr);
4235     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
4236                            lhsBlock, rhs, rhsBlock, out);
4237 
4238     return result;
4239 }
4240 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)4241 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
4242     const Type& type = t.type();
4243     SpvId test = this->writeExpression(*t.test(), out);
4244     if (t.ifTrue()->type().columns() == 1 &&
4245         Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
4246         Analysis::IsCompileTimeConstant(*t.ifFalse())) {
4247         // both true and false are constants, can just use OpSelect
4248         SpvId result = this->nextId(nullptr);
4249         SpvId trueId = this->writeExpression(*t.ifTrue(), out);
4250         SpvId falseId = this->writeExpression(*t.ifFalse(), out);
4251         this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
4252                                out);
4253         return result;
4254     }
4255 
4256     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4257 
4258     // was originally using OpPhi to choose the result, but for some reason that is crashing on
4259     // Adreno. Switched to storing the result in a temp variable as glslang does.
4260     SpvId var = this->nextId(nullptr);
4261     this->writeInstruction(SpvOpVariable, this->getPointerType(type, StorageClass::kFunction),
4262                            var, SpvStorageClassFunction, fVariableBuffer);
4263     SpvId trueLabel = this->nextId(nullptr);
4264     SpvId falseLabel = this->nextId(nullptr);
4265     SpvId end = this->nextId(nullptr);
4266     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4267     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
4268     this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
4269     this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifTrue(), out), out);
4270     this->writeInstruction(SpvOpBranch, end, out);
4271     this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
4272     this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifFalse(), out), out);
4273     this->writeInstruction(SpvOpBranch, end, out);
4274     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4275     SpvId result = this->nextId(&type);
4276     this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
4277 
4278     return result;
4279 }
4280 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4281 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4282     const Type& type = p.type();
4283     if (p.getOperator().kind() == Operator::Kind::MINUS) {
4284         SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4285         SkASSERT(negateOp != SpvOpUndef);
4286         SpvId expr = this->writeExpression(*p.operand(), out);
4287         if (type.isMatrix()) {
4288             return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4289         }
4290         SpvId result = this->nextId(&type);
4291         SpvId typeId = this->getType(type);
4292         this->writeInstruction(negateOp, typeId, result, expr, out);
4293         return result;
4294     }
4295     switch (p.getOperator().kind()) {
4296         case Operator::Kind::PLUS:
4297             return this->writeExpression(*p.operand(), out);
4298 
4299         case Operator::Kind::PLUSPLUS: {
4300             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4301             SpvId one = this->writeLiteral(1.0, type.componentType());
4302             one = this->splat(type, one, out);
4303             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4304                                                                            lv->load(out), one,
4305                                                                            SpvOpFAdd, SpvOpIAdd,
4306                                                                            SpvOpIAdd, SpvOpUndef,
4307                                                                            out);
4308             lv->store(result, out);
4309             return result;
4310         }
4311         case Operator::Kind::MINUSMINUS: {
4312             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4313             SpvId one = this->writeLiteral(1.0, type.componentType());
4314             one = this->splat(type, one, out);
4315             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4316                                                                            lv->load(out), one,
4317                                                                            SpvOpFSub, SpvOpISub,
4318                                                                            SpvOpISub, SpvOpUndef,
4319                                                                            out);
4320             lv->store(result, out);
4321             return result;
4322         }
4323         case Operator::Kind::LOGICALNOT: {
4324             SkASSERT(p.operand()->type().isBoolean());
4325             SpvId result = this->nextId(nullptr);
4326             this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4327                                    this->writeExpression(*p.operand(), out), out);
4328             return result;
4329         }
4330         case Operator::Kind::BITWISENOT: {
4331             SpvId result = this->nextId(nullptr);
4332             this->writeInstruction(SpvOpNot, this->getType(type), result,
4333                                    this->writeExpression(*p.operand(), out), out);
4334             return result;
4335         }
4336         default:
4337             SkDEBUGFAILF("unsupported prefix expression: %s",
4338                          p.description(OperatorPrecedence::kExpression).c_str());
4339             return NA;
4340     }
4341 }
4342 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4343 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4344     const Type& type = p.type();
4345     std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4346     SpvId result = lv->load(out);
4347     SpvId one = this->writeLiteral(1.0, type.componentType());
4348     one = this->splat(type, one, out);
4349     switch (p.getOperator().kind()) {
4350         case Operator::Kind::PLUSPLUS: {
4351             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4352                                                                          SpvOpFAdd, SpvOpIAdd,
4353                                                                          SpvOpIAdd, SpvOpUndef,
4354                                                                          out);
4355             lv->store(temp, out);
4356             return result;
4357         }
4358         case Operator::Kind::MINUSMINUS: {
4359             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4360                                                                          SpvOpFSub, SpvOpISub,
4361                                                                          SpvOpISub, SpvOpUndef,
4362                                                                          out);
4363             lv->store(temp, out);
4364             return result;
4365         }
4366         default:
4367             SkDEBUGFAILF("unsupported postfix expression %s",
4368                          p.description(OperatorPrecedence::kExpression).c_str());
4369             return NA;
4370     }
4371 }
4372 
writeLiteral(const Literal & l)4373 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4374     return this->writeLiteral(l.value(), l.type());
4375 }
4376 
writeLiteral(double value,const Type & type)4377 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4378     switch (type.numberKind()) {
4379         case Type::NumberKind::kFloat: {
4380             float floatVal = value;
4381             int32_t valueBits;
4382             memcpy(&valueBits, &floatVal, sizeof(valueBits));
4383             return this->writeOpConstant(type, valueBits);
4384         }
4385         case Type::NumberKind::kBoolean: {
4386             return value ? this->writeOpConstantTrue(type)
4387                          : this->writeOpConstantFalse(type);
4388         }
4389         default: {
4390             return this->writeOpConstant(type, (SKSL_INT)value);
4391         }
4392     }
4393 }
4394 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)4395 void SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
4396     SpvId result = fFunctionMap[{&f, fActiveSpecializationIndex}];
4397     SpvId returnTypeId = this->getType(f.returnType());
4398     SpvId functionTypeId = this->getFunctionType(f);
4399     this->writeInstruction(SpvOpFunction, returnTypeId, result,
4400                            SpvFunctionControlMaskNone, functionTypeId, out);
4401     std::string mangledName = f.mangledName();
4402 
4403     // For specialized functions, tack on `_param1_param2` to the function name.
4404     Analysis::GetParameterMappingsForFunction(
4405             f, fSpecializationInfo, fActiveSpecializationIndex,
4406             [&](int, const Variable*, const Expression* expr) {
4407                 std::string name = expr->description();
4408                 std::replace_if(name.begin(), name.end(), [](char c) { return !isalnum(c); }, '_');
4409 
4410                 mangledName += "_" + name;
4411             });
4412 
4413     this->writeInstruction(SpvOpName,
4414                            result,
4415                            std::string_view(mangledName.c_str(), mangledName.size()),
4416                            fNameBuffer);
4417     for (const Variable* parameter : f.parameters()) {
4418         const Variable* specializedVar = nullptr;
4419         if (fActiveSpecialization) {
4420             if (const Expression** specializedExpr = fActiveSpecialization->find(parameter)) {
4421                 if ((*specializedExpr)->is<FieldAccess>()) {
4422                     continue;
4423                 }
4424                 SkASSERT((*specializedExpr)->is<VariableReference>());
4425                 specializedVar = (*specializedExpr)->as<VariableReference>().variable();
4426             }
4427         }
4428 
4429         if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
4430             auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
4431 
4432             SpvId textureId = this->nextId(nullptr);
4433             fVariableMap.set(texture, textureId);
4434 
4435             SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
4436             this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
4437 
4438             if (specializedVar) {
4439                 const auto* p = fSynthesizedSamplerMap.find(specializedVar);
4440                 SkASSERT(p);
4441                 const SpvId* uniformId = fVariableMap.find((*p)->fSampler.get());
4442                 SkASSERT(uniformId);
4443                 fVariableMap.set(sampler, *uniformId);
4444             } else {
4445                 SpvId samplerId = this->nextId(nullptr);
4446                 fVariableMap.set(sampler, samplerId);
4447 
4448                 SpvId samplerType =
4449                         this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);
4450                 this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
4451             }
4452         } else {
4453             if (specializedVar) {
4454                 const SpvId* uniformId = fVariableMap.find(specializedVar);
4455                 SkASSERT(uniformId);
4456                 fVariableMap.set(parameter, *uniformId);
4457             } else {
4458                 SpvId id = this->nextId(nullptr);
4459                 fVariableMap.set(parameter, id);
4460 
4461                 SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
4462                 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
4463             }
4464         }
4465     }
4466 }
4467 
writeFunction(const FunctionDefinition & f,OutputStream & out)4468 void SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
4469     if (const Analysis::Specializations* specializations =
4470                 fSpecializationInfo.fSpecializationMap.find(&f.declaration())) {
4471         for (int i = 0; i < specializations->size(); i++) {
4472             this->writeFunctionInstantiation(f, i, &specializations->at(i), out);
4473         }
4474     } else {
4475         this->writeFunctionInstantiation(f,
4476                                          Analysis::kUnspecialized,
4477                                          /*specializedParams=*/nullptr,
4478                                          out);
4479     }
4480 }
4481 
writeFunctionInstantiation(const FunctionDefinition & f,Analysis::SpecializationIndex specializationIndex,const Analysis::SpecializedParameters * specializedParams,OutputStream & out)4482 void SPIRVCodeGenerator::writeFunctionInstantiation(
4483         const FunctionDefinition& f,
4484         Analysis::SpecializationIndex specializationIndex,
4485         const Analysis::SpecializedParameters* specializedParams,
4486         OutputStream& out) {
4487     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4488 
4489     fVariableBuffer.reset();
4490     fActiveSpecialization = specializedParams;
4491     fActiveSpecializationIndex = specializationIndex;
4492     this->writeFunctionStart(f.declaration(), out);
4493     fCurrentBlock = 0;
4494     this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
4495     StringStream bodyBuffer;
4496     this->writeBlock(f.body()->as<Block>(), bodyBuffer);
4497     fActiveSpecialization = nullptr;
4498     fActiveSpecializationIndex = Analysis::kUnspecialized;
4499     write_stringstream(fVariableBuffer, out);
4500     if (f.declaration().isMain()) {
4501         write_stringstream(fGlobalInitializersBuffer, out);
4502     }
4503     write_stringstream(bodyBuffer, out);
4504     if (fCurrentBlock) {
4505         if (f.declaration().returnType().isVoid()) {
4506             this->writeInstruction(SpvOpReturn, out);
4507         } else {
4508             this->writeInstruction(SpvOpUnreachable, out);
4509         }
4510     }
4511     this->writeInstruction(SpvOpFunctionEnd, out);
4512     this->pruneConditionalOps(conditionalOps);
4513 }
4514 
writeLayout(const Layout & layout,SpvId target,Position pos)4515 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4516     bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4517     if (layout.fLocation >= 0) {
4518         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4519                                fDecorationBuffer);
4520     }
4521     if (layout.fBinding >= 0) {
4522         if (isPushConstant) {
4523             fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4524         } else {
4525             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4526                                    fDecorationBuffer);
4527         }
4528     }
4529     if (layout.fIndex >= 0) {
4530         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4531                                fDecorationBuffer);
4532     }
4533     if (layout.fSet >= 0) {
4534         if (isPushConstant) {
4535             fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4536         } else {
4537             this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4538                                    fDecorationBuffer);
4539         }
4540     }
4541     if (layout.fInputAttachmentIndex >= 0) {
4542         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4543                                layout.fInputAttachmentIndex, fDecorationBuffer);
4544         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4545     }
4546     if (layout.fBuiltin >= 0 && (layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
4547                                  layout.fBuiltin != SK_SECONDARYFRAGCOLOR_BUILTIN)) {
4548             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4549                                    fDecorationBuffer);
4550     }
4551 }
4552 
writeFieldLayout(const Layout & layout,SpvId target,int member)4553 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4554     // 'binding' and 'set' can not be applied to struct members
4555     SkASSERT(layout.fBinding == -1);
4556     SkASSERT(layout.fSet == -1);
4557     if (layout.fLocation >= 0) {
4558         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4559                                layout.fLocation, fDecorationBuffer);
4560     }
4561     if (layout.fIndex >= 0) {
4562         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4563                                layout.fIndex, fDecorationBuffer);
4564     }
4565     if (layout.fInputAttachmentIndex >= 0) {
4566         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4567                                layout.fInputAttachmentIndex, fDecorationBuffer);
4568     }
4569     if (layout.fBuiltin >= 0) {
4570         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4571                                layout.fBuiltin, fDecorationBuffer);
4572     }
4573 }
4574 
memoryLayoutForStorageClass(StorageClass storageClass)4575 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(StorageClass storageClass) {
4576     return storageClass == StorageClass::kPushConstant ||
4577            storageClass == StorageClass::kStorageBuffer
4578                                 ? MemoryLayout(MemoryLayout::Standard::k430)
4579                                 : fDefaultMemoryLayout;
4580 }
4581 
memoryLayoutForVariable(const Variable & v) const4582 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4583     bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4584     bool buffer = v.modifierFlags().isBuffer();
4585     return pushConstant || buffer ? MemoryLayout(MemoryLayout::Standard::k430)
4586                                   : fDefaultMemoryLayout;
4587 }
4588 
writeInterfaceBlock(const InterfaceBlock & intf,bool appendRTFlip)4589 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
4590     MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
4591     SpvId result = this->nextId(nullptr);
4592     const Variable& intfVar = *intf.var();
4593     const Type& type = intfVar.type();
4594     if (!memoryLayout.isSupported(type)) {
4595         fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
4596                                                 "' is not permitted here");
4597         return this->nextId(nullptr);
4598     }
4599     StorageClass storageClass =
4600             get_storage_class_for_global_variable(intfVar, StorageClass::kFunction);
4601     if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
4602         !fWroteRTFlip && type.isStruct()) {
4603         // We can only have one interface block (because we use push_constant and that is limited
4604         // to one per program), so we need to append rtflip to this one rather than synthesize an
4605         // entirely new block when the variable is referenced. And we can't modify the existing
4606         // block, so we instead create a modified copy of it and write that.
4607         SkSpan<const Field> fieldSpan = type.fields();
4608         TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
4609         fields.emplace_back(Position(),
4610                             Layout(LayoutFlag::kNone,
4611                                    /*location=*/-1,
4612                                    fProgram.fConfig->fSettings.fRTFlipOffset,
4613                                    /*binding=*/-1,
4614                                    /*index=*/-1,
4615                                    /*set=*/-1,
4616                                    /*builtin=*/-1,
4617                                    /*inputAttachmentIndex=*/-1),
4618                             ModifierFlag::kNone,
4619                             SKSL_RTFLIP_NAME,
4620                             fContext.fTypes.fFloat2.get());
4621         {
4622             AutoAttachPoolToThread attach(fProgram.fPool.get());
4623             const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
4624                     Type::MakeStructType(fContext,
4625                                          type.fPosition,
4626                                          type.name(),
4627                                          std::move(fields),
4628                                          /*interfaceBlock=*/true));
4629             Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
4630                     Variable::Make(intfVar.fPosition,
4631                                    intfVar.modifiersPosition(),
4632                                    intfVar.layout(),
4633                                    intfVar.modifierFlags(),
4634                                    rtFlipStructType,
4635                                    intfVar.name(),
4636                                    /*mangledName=*/"",
4637                                    intfVar.isBuiltin(),
4638                                    intfVar.storage()));
4639             InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
4640             result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
4641             fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
4642                     Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
4643         }
4644         fVariableMap.set(&intfVar, result);
4645         fWroteRTFlip = true;
4646         return result;
4647     }
4648     SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
4649     if (intfVar.layout().fBuiltin == -1) {
4650         // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
4651         // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
4652         // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
4653         // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
4654         // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
4655         // would benefit from SPV_KHR_variable_pointers capabilities).
4656         bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
4657         this->writeInstruction(SpvOpDecorate,
4658                                typeId,
4659                                isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
4660                                fDecorationBuffer);
4661     }
4662     SpvId ptrType = this->nextId(nullptr);
4663     this->writeInstruction(SpvOpTypePointer, ptrType,
4664                            get_storage_class_spv_id(storageClass), typeId, fConstantBuffer);
4665     this->writeInstruction(SpvOpVariable, ptrType, result,
4666                            get_storage_class_spv_id(storageClass), fConstantBuffer);
4667     Layout layout = intfVar.layout();
4668     if ((storageClass == StorageClass::kUniform ||
4669                 storageClass == StorageClass::kStorageBuffer) && layout.fSet < 0) {
4670         layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
4671     }
4672     this->writeLayout(layout, result, intfVar.fPosition);
4673     fVariableMap.set(&intfVar, result);
4674     return result;
4675 }
4676 
4677 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4678 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4679 // always reference by value.
4680 //
4681 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4682 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4683 // OpVectorExtractDynamic (see writeIndexExpression).
4684 //
4685 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4686 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4687 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4688     return varDecl.var()->modifierFlags().isConst() &&
4689            (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4690            (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4691             Analysis::IsCompileTimeConstant(*varDecl.value()));
4692 }
4693 
writeGlobalVarDeclaration(ProgramKind kind,const VarDeclaration & varDecl)4694 bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
4695                                                    const VarDeclaration& varDecl) {
4696     const Variable* var = varDecl.var();
4697     const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
4698     const LayoutFlags kPermittedBackendFlags =
4699             LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
4700     if (backendFlags & ~kPermittedBackendFlags) {
4701         fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
4702         return false;
4703     }
4704 
4705     // If this global variable is a compile-time constant then we'll emit OpConstant or
4706     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4707     if (is_vardecl_compile_time_constant(varDecl)) {
4708         return true;
4709     }
4710 
4711     StorageClass storageClass =
4712             get_storage_class_for_global_variable(*var, StorageClass::kPrivate);
4713     if (storageClass == StorageClass::kUniform || storageClass == StorageClass::kStorageBuffer) {
4714         // Top-level uniforms are emitted in writeUniformBuffer.
4715         fTopLevelUniforms.push_back(&varDecl);
4716         return true;
4717     }
4718 
4719     if (fUseTextureSamplerPairs && var->type().isSampler()) {
4720         if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
4721             fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
4722                                                     "and sampler indices");
4723             return false;
4724         }
4725         SkASSERT(storageClass == StorageClass::kUniformConstant);
4726 
4727         auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
4728         this->writeGlobalVar(kind, storageClass, *texture);
4729         this->writeGlobalVar(kind, storageClass, *sampler);
4730 
4731         return true;
4732     }
4733 
4734     SpvId id = this->writeGlobalVar(kind, storageClass, *var);
4735     if (id != NA && varDecl.value()) {
4736         SkASSERT(!fCurrentBlock);
4737         fCurrentBlock = NA;
4738         SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
4739         this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
4740         fCurrentBlock = 0;
4741     }
4742     return true;
4743 }
4744 
writeGlobalVar(ProgramKind kind,StorageClass storageClass,const Variable & var)4745 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
4746                                          StorageClass storageClass,
4747                                          const Variable& var) {
4748     Layout layout = var.layout();
4749     const ModifierFlags flags = var.modifierFlags();
4750     const Type* type = &var.type();
4751     switch (layout.fBuiltin) {
4752         case SK_FRAGCOLOR_BUILTIN:
4753         case SK_SECONDARYFRAGCOLOR_BUILTIN:
4754             if (!ProgramConfig::IsFragment(kind)) {
4755                 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
4756                 return NA;
4757             }
4758             break;
4759 
4760         case SK_SAMPLEMASKIN_BUILTIN:
4761         case SK_SAMPLEMASK_BUILTIN:
4762             // SkSL exposes this as a `uint` but SPIR-V, like GLSL, uses an array of signed `uint`
4763             // decorated with SpvBuiltinSampleMask.
4764             type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
4765             layout.fBuiltin = SpvBuiltInSampleMask;
4766             break;
4767     }
4768 
4769     // Add this global to the variable map.
4770     SpvId id = this->nextId(type);
4771     fVariableMap.set(&var, id);
4772 
4773     if (layout.fSet < 0 && storageClass == StorageClass::kUniformConstant) {
4774         layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
4775     }
4776 
4777     SpvId typeId = this->getPointerType(*type,
4778                                         layout,
4779                                         this->memoryLayoutForStorageClass(storageClass),
4780                                         storageClass);
4781     this->writeInstruction(SpvOpVariable, typeId, id,
4782                            get_storage_class_spv_id(storageClass), fConstantBuffer);
4783     this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
4784     this->writeLayout(layout, id, var.fPosition);
4785     if (flags & ModifierFlag::kFlat) {
4786         this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
4787     }
4788     if (flags & ModifierFlag::kNoPerspective) {
4789         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
4790                                fDecorationBuffer);
4791     }
4792     if (flags.isWriteOnly()) {
4793         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
4794     } else if (flags.isReadOnly()) {
4795         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
4796     }
4797 
4798     return id;
4799 }
4800 
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)4801 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
4802     // If this variable is a compile-time constant then we'll emit OpConstant or
4803     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4804     if (is_vardecl_compile_time_constant(varDecl)) {
4805         return;
4806     }
4807 
4808     const Variable* var = varDecl.var();
4809     SpvId id = this->nextId(&var->type());
4810     fVariableMap.set(var, id);
4811     SpvId type = this->getPointerType(var->type(), StorageClass::kFunction);
4812     this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
4813     this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
4814     if (varDecl.value()) {
4815         SpvId value = this->writeExpression(*varDecl.value(), out);
4816         this->writeOpStore(StorageClass::kFunction, id, value, out);
4817     }
4818 }
4819 
writeStatement(const Statement & s,OutputStream & out)4820 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
4821     switch (s.kind()) {
4822         case Statement::Kind::kNop:
4823             break;
4824         case Statement::Kind::kBlock:
4825             this->writeBlock(s.as<Block>(), out);
4826             break;
4827         case Statement::Kind::kExpression:
4828             this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
4829             break;
4830         case Statement::Kind::kReturn:
4831             this->writeReturnStatement(s.as<ReturnStatement>(), out);
4832             break;
4833         case Statement::Kind::kVarDeclaration:
4834             this->writeVarDeclaration(s.as<VarDeclaration>(), out);
4835             break;
4836         case Statement::Kind::kIf:
4837             this->writeIfStatement(s.as<IfStatement>(), out);
4838             break;
4839         case Statement::Kind::kFor:
4840             this->writeForStatement(s.as<ForStatement>(), out);
4841             break;
4842         case Statement::Kind::kDo:
4843             this->writeDoStatement(s.as<DoStatement>(), out);
4844             break;
4845         case Statement::Kind::kSwitch:
4846             this->writeSwitchStatement(s.as<SwitchStatement>(), out);
4847             break;
4848         case Statement::Kind::kBreak:
4849             this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
4850             break;
4851         case Statement::Kind::kContinue:
4852             this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
4853             break;
4854         case Statement::Kind::kDiscard:
4855             this->writeInstruction(SpvOpKill, out);
4856             break;
4857         default:
4858             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
4859             break;
4860     }
4861 }
4862 
writeBlock(const Block & b,OutputStream & out)4863 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
4864     for (const std::unique_ptr<Statement>& stmt : b.children()) {
4865         this->writeStatement(*stmt, out);
4866     }
4867 }
4868 
getConditionalOpCounts()4869 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
4870     return {fReachableOps.size(), fStoreOps.size()};
4871 }
4872 
pruneConditionalOps(ConditionalOpCounts ops)4873 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
4874     // Remove ops which are no longer reachable.
4875     while (fReachableOps.size() > ops.numReachableOps) {
4876         SpvId prunableSpvId = fReachableOps.back();
4877         const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
4878 
4879         if (prunableOp) {
4880             fOpCache.remove(*prunableOp);
4881             fSpvIdCache.remove(prunableSpvId);
4882         } else {
4883             SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
4884         }
4885 
4886         fReachableOps.pop_back();
4887     }
4888 
4889     // Remove any cached stores that occurred during the conditional block.
4890     while (fStoreOps.size() > ops.numStoreOps) {
4891         if (fStoreCache.find(fStoreOps.back())) {
4892             fStoreCache.remove(fStoreOps.back());
4893         }
4894         fStoreOps.pop_back();
4895     }
4896 }
4897 
writeIfStatement(const IfStatement & stmt,OutputStream & out)4898 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
4899     SpvId test = this->writeExpression(*stmt.test(), out);
4900     SpvId ifTrue = this->nextId(nullptr);
4901     SpvId ifFalse = this->nextId(nullptr);
4902 
4903     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4904 
4905     if (stmt.ifFalse()) {
4906         SpvId end = this->nextId(nullptr);
4907         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4908         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4909         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4910         this->writeStatement(*stmt.ifTrue(), out);
4911         if (fCurrentBlock) {
4912             this->writeInstruction(SpvOpBranch, end, out);
4913         }
4914         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4915         this->writeStatement(*stmt.ifFalse(), out);
4916         if (fCurrentBlock) {
4917             this->writeInstruction(SpvOpBranch, end, out);
4918         }
4919         this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4920     } else {
4921         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
4922         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4923         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4924         this->writeStatement(*stmt.ifTrue(), out);
4925         if (fCurrentBlock) {
4926             this->writeInstruction(SpvOpBranch, ifFalse, out);
4927         }
4928         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4929     }
4930 }
4931 
writeForStatement(const ForStatement & f,OutputStream & out)4932 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
4933     if (f.initializer()) {
4934         this->writeStatement(*f.initializer(), out);
4935     }
4936 
4937     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4938 
4939     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4940     // in the context of linear straight-line execution. If we wanted to be more clever, we could
4941     // only invalidate store cache entries for variables affected by the loop body, but for now we
4942     // simply clear the entire cache whenever branching occurs.
4943     SpvId header = this->nextId(nullptr);
4944     SpvId start = this->nextId(nullptr);
4945     SpvId body = this->nextId(nullptr);
4946     SpvId next = this->nextId(nullptr);
4947     fContinueTarget.push_back(next);
4948     SpvId end = this->nextId(nullptr);
4949     fBreakTarget.push_back(end);
4950     this->writeInstruction(SpvOpBranch, header, out);
4951     this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
4952     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
4953     this->writeInstruction(SpvOpBranch, start, out);
4954     this->writeLabel(start, kBranchIsOnPreviousLine, out);
4955     if (f.test()) {
4956         SpvId test = this->writeExpression(*f.test(), out);
4957         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
4958     } else {
4959         this->writeInstruction(SpvOpBranch, body, out);
4960     }
4961     this->writeLabel(body, kBranchIsOnPreviousLine, out);
4962     this->writeStatement(*f.statement(), out);
4963     if (fCurrentBlock) {
4964         this->writeInstruction(SpvOpBranch, next, out);
4965     }
4966     this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
4967     if (f.next()) {
4968         this->writeExpression(*f.next(), out);
4969     }
4970     this->writeInstruction(SpvOpBranch, header, out);
4971     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4972     fBreakTarget.pop_back();
4973     fContinueTarget.pop_back();
4974 }
4975 
writeDoStatement(const DoStatement & d,OutputStream & out)4976 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
4977     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4978 
4979     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4980     // in the context of linear straight-line execution. If we wanted to be more clever, we could
4981     // only invalidate store cache entries for variables affected by the loop body, but for now we
4982     // simply clear the entire cache whenever branching occurs.
4983     SpvId header = this->nextId(nullptr);
4984     SpvId start = this->nextId(nullptr);
4985     SpvId next = this->nextId(nullptr);
4986     SpvId continueTarget = this->nextId(nullptr);
4987     fContinueTarget.push_back(continueTarget);
4988     SpvId end = this->nextId(nullptr);
4989     fBreakTarget.push_back(end);
4990     this->writeInstruction(SpvOpBranch, header, out);
4991     this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
4992     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
4993     this->writeInstruction(SpvOpBranch, start, out);
4994     this->writeLabel(start, kBranchIsOnPreviousLine, out);
4995     this->writeStatement(*d.statement(), out);
4996     if (fCurrentBlock) {
4997         this->writeInstruction(SpvOpBranch, next, out);
4998         this->writeLabel(next, kBranchIsOnPreviousLine, out);
4999         this->writeInstruction(SpvOpBranch, continueTarget, out);
5000     }
5001     this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
5002     SpvId test = this->writeExpression(*d.test(), out);
5003     this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
5004     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5005     fBreakTarget.pop_back();
5006     fContinueTarget.pop_back();
5007 }
5008 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)5009 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
5010     SpvId value = this->writeExpression(*s.value(), out);
5011 
5012     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
5013 
5014     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
5015     // in the context of linear straight-line execution. If we wanted to be more clever, we could
5016     // only invalidate store cache entries for variables affected by the switch body, but for now we
5017     // simply clear the entire cache whenever branching occurs.
5018     TArray<SpvId> labels;
5019     SpvId end = this->nextId(nullptr);
5020     SpvId defaultLabel = end;
5021     fBreakTarget.push_back(end);
5022     int size = 3;
5023     const StatementArray& cases = s.cases();
5024     for (const std::unique_ptr<Statement>& stmt : cases) {
5025         const SwitchCase& c = stmt->as<SwitchCase>();
5026         SpvId label = this->nextId(nullptr);
5027         labels.push_back(label);
5028         if (!c.isDefault()) {
5029             size += 2;
5030         } else {
5031             defaultLabel = label;
5032         }
5033     }
5034 
5035     // We should have exactly one label for each case.
5036     SkASSERT(labels.size() == cases.size());
5037 
5038     // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
5039     // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
5040     // does support multiple switch-cases branching to the same label.
5041     SkBitSet caseIsCollapsed(cases.size());
5042     for (int index = cases.size() - 2; index >= 0; index--) {
5043         if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
5044             caseIsCollapsed.set(index);
5045             labels[index] = labels[index + 1];
5046         }
5047     }
5048 
5049     labels.push_back(end);
5050 
5051     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
5052     this->writeOpCode(SpvOpSwitch, size, out);
5053     this->writeWord(value, out);
5054     this->writeWord(defaultLabel, out);
5055     for (int i = 0; i < cases.size(); ++i) {
5056         const SwitchCase& c = cases[i]->as<SwitchCase>();
5057         if (c.isDefault()) {
5058             continue;
5059         }
5060         this->writeWord(c.value(), out);
5061         this->writeWord(labels[i], out);
5062     }
5063     for (int i = 0; i < cases.size(); ++i) {
5064         if (caseIsCollapsed.test(i)) {
5065             continue;
5066         }
5067         const SwitchCase& c = cases[i]->as<SwitchCase>();
5068         if (i == 0) {
5069             this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
5070         } else {
5071             this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
5072         }
5073         this->writeStatement(*c.statement(), out);
5074         if (fCurrentBlock) {
5075             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
5076         }
5077     }
5078     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5079     fBreakTarget.pop_back();
5080 }
5081 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)5082 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
5083     if (r.expression()) {
5084         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
5085                                out);
5086     } else {
5087         this->writeInstruction(SpvOpReturn, out);
5088     }
5089 }
5090 
5091 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)5092 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
5093     return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
5094 }
5095 
writeEntrypointAdapter(const FunctionDeclaration & main)5096 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
5097         const FunctionDeclaration& main) {
5098     // Our goal is to synthesize a tiny helper function which looks like this:
5099     //     void _entrypoint() { sk_FragColor = main(); }
5100 
5101     // Fish a symbol table out of main().
5102     SymbolTable* symbolTable = get_top_level_symbol_table(main);
5103 
5104     // Get `sk_FragColor` as a writable reference.
5105     const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
5106     SkASSERT(skFragColorSymbol);
5107     const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
5108     auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
5109                                                               VariableReference::RefKind::kWrite);
5110 
5111     // TODO get secondary frag color as one as well?
5112 
5113     // Synthesize a call to the `main()` function.
5114     if (!main.returnType().matches(skFragColorRef->type())) {
5115         fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
5116                 main.returnType().description() + "' from main()");
5117         return {};
5118     }
5119     ExpressionArray args;
5120     if (main.parameters().size() == 1) {
5121         if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
5122             fContext.fErrors->error(main.fPosition,
5123                     "SPIR-V does not support parameter of type '" +
5124                     main.parameters()[0]->type().description() + "' to main()");
5125             return {};
5126         }
5127         double kZero[2] = {0.0, 0.0};
5128         args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
5129                                                               *fContext.fTypes.fFloat2, kZero));
5130     }
5131     auto callMainFn = FunctionCall::Make(fContext, Position(), &main.returnType(),
5132                                          main, std::move(args));
5133 
5134     // Synthesize `skFragColor = main()` as a BinaryExpression.
5135     auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
5136             Position(),
5137             std::move(skFragColorRef),
5138             Operator::Kind::EQ,
5139             std::move(callMainFn),
5140             &main.returnType()));
5141 
5142     // Function bodies are always wrapped in a Block.
5143     StatementArray entrypointStmts;
5144     entrypointStmts.push_back(std::move(assignmentStmt));
5145     auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
5146                                        Block::Kind::kBracedScope, /*symbols=*/nullptr);
5147     // Declare an entrypoint function.
5148     EntrypointAdapter adapter;
5149     adapter.entrypointDecl =
5150             std::make_unique<FunctionDeclaration>(fContext,
5151                                                   Position(),
5152                                                   ModifierFlag::kNone,
5153                                                   "_entrypoint",
5154                                                   /*parameters=*/TArray<Variable*>{},
5155                                                   /*returnType=*/fContext.fTypes.fVoid.get(),
5156                                                   kNotIntrinsic);
5157     // Define it.
5158     adapter.entrypointDef = FunctionDefinition::Convert(fContext,
5159                                                         Position(),
5160                                                         *adapter.entrypointDecl,
5161                                                         std::move(entrypointBlock));
5162 
5163     adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
5164     return adapter;
5165 }
5166 
writeUniformBuffer(SymbolTable * topLevelSymbolTable)5167 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5168     SkASSERT(!fTopLevelUniforms.empty());
5169     static constexpr char kUniformBufferName[] = "_UniformBuffer";
5170 
5171     // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5172     // a lookup table of variables to UniformBuffer field indices.
5173     TArray<Field> fields;
5174     fields.reserve_exact(fTopLevelUniforms.size());
5175     for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5176         const Variable* var = topLevelUniform->var();
5177         fTopLevelUniformMap.set(var, (int)fields.size());
5178         ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5179         fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5180     }
5181     fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5182                                                   Position(),
5183                                                   kUniformBufferName,
5184                                                   std::move(fields),
5185                                                   /*interfaceBlock=*/true);
5186 
5187     // Create a global variable to contain this struct.
5188     Layout layout;
5189     layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5190     layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
5191 
5192     fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5193                                                    /*modifiersPosition=*/Position(),
5194                                                    layout,
5195                                                    ModifierFlag::kUniform,
5196                                                    fUniformBuffer.fStruct.get(),
5197                                                    kUniformBufferName,
5198                                                    /*mangledName=*/"",
5199                                                    /*builtin=*/false,
5200                                                    Variable::Storage::kGlobal);
5201 
5202     // Create an interface block object for this global variable.
5203     fUniformBuffer.fInterfaceBlock =
5204             std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5205 
5206     // Generate an interface block and hold onto its ID.
5207     fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5208 }
5209 
addRTFlipUniform(Position pos)5210 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
5211     SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
5212 
5213     if (fWroteRTFlip) {
5214         return;
5215     }
5216     // Flip variable hasn't been written yet. This means we don't have an existing
5217     // interface block, so we're free to just synthesize one.
5218     fWroteRTFlip = true;
5219     TArray<Field> fields;
5220     if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
5221         fContext.fErrors->error(pos, "RTFlipOffset is negative");
5222     }
5223     fields.emplace_back(pos,
5224                         Layout(LayoutFlag::kNone,
5225                                /*location=*/-1,
5226                                fProgram.fConfig->fSettings.fRTFlipOffset,
5227                                /*binding=*/-1,
5228                                /*index=*/-1,
5229                                /*set=*/-1,
5230                                /*builtin=*/-1,
5231                                /*inputAttachmentIndex=*/-1),
5232                         ModifierFlag::kNone,
5233                         SKSL_RTFLIP_NAME,
5234                         fContext.fTypes.fFloat2.get());
5235     std::string_view name = "sksl_synthetic_uniforms";
5236     const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
5237             fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
5238     bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
5239     int binding = -1, set = -1;
5240     if (!usePushConstants) {
5241         binding = fProgram.fConfig->fSettings.fRTFlipBinding;
5242         if (binding == -1) {
5243             fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
5244         }
5245         set = fProgram.fConfig->fSettings.fRTFlipSet;
5246         if (set == -1) {
5247             fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
5248         }
5249     }
5250     Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
5251                   /*location=*/-1,
5252                   /*offset=*/-1,
5253                   binding,
5254                   /*index=*/-1,
5255                   set,
5256                   /*builtin=*/-1,
5257                   /*inputAttachmentIndex=*/-1);
5258     Variable* intfVar =
5259             fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
5260                                                              /*modifiersPosition=*/Position(),
5261                                                              layout,
5262                                                              ModifierFlag::kUniform,
5263                                                              intfStruct,
5264                                                              name,
5265                                                              /*mangledName=*/"",
5266                                                              /*builtin=*/false,
5267                                                              Variable::Storage::kGlobal));
5268     {
5269         AutoAttachPoolToThread attach(fProgram.fPool.get());
5270         fProgram.fSymbols->add(fContext,
5271                                std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
5272     }
5273     InterfaceBlock intf(Position(), intfVar);
5274     this->writeInterfaceBlock(intf, false);
5275 }
5276 
synthesizeTextureAndSampler(const Variable & combinedSampler)5277 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5278         const Variable& combinedSampler) {
5279     SkASSERT(fUseTextureSamplerPairs);
5280     SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5281 
5282     if (std::unique_ptr<SynthesizedTextureSamplerPair>* existingData =
5283             fSynthesizedSamplerMap.find(&combinedSampler)) {
5284         return {(*existingData)->fTexture.get(), (*existingData)->fSampler.get()};
5285     }
5286 
5287     auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5288 
5289     Layout texLayout = combinedSampler.layout();
5290     texLayout.fBinding = texLayout.fTexture;
5291     data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5292 
5293     auto texture = Variable::Make(/*pos=*/Position(),
5294                                   /*modifiersPosition=*/Position(),
5295                                   texLayout,
5296                                   combinedSampler.modifierFlags(),
5297                                   &combinedSampler.type().textureType(),
5298                                   data->fTextureName,
5299                                   /*mangledName=*/"",
5300                                   /*builtin=*/false,
5301                                   Variable::Storage::kGlobal);
5302 
5303     Layout samplerLayout = combinedSampler.layout();
5304     samplerLayout.fBinding = samplerLayout.fSampler;
5305     samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5306     data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5307 
5308     auto sampler = Variable::Make(/*pos=*/Position(),
5309                                   /*modifiersPosition=*/Position(),
5310                                   samplerLayout,
5311                                   combinedSampler.modifierFlags(),
5312                                   fContext.fTypes.fSampler.get(),
5313                                   data->fSamplerName,
5314                                   /*mangledName=*/"",
5315                                   /*builtin=*/false,
5316                                   Variable::Storage::kGlobal);
5317 
5318     const Variable* t = texture.get();
5319     const Variable* s = sampler.get();
5320     data->fTexture = std::move(texture);
5321     data->fSampler = std::move(sampler);
5322     fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5323 
5324     return {t, s};
5325 }
5326 
writeInstructions(const Program & program,OutputStream & out)5327 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
5328     Analysis::FindFunctionsToSpecialize(program, &fSpecializationInfo, [](const Variable& param) {
5329         return param.type().isSampler() || param.type().isUnsizedArray();
5330     });
5331 
5332     fGLSLExtendedInstructions = this->nextId(nullptr);
5333     StringStream body;
5334 
5335     // Do an initial pass over the program elements to establish some baseline info.
5336     const FunctionDeclaration* main = nullptr;
5337     int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
5338     Position combinedSamplerPos;
5339     Position separateSamplerPos;
5340     for (const ProgramElement* e : program.elements()) {
5341         switch (e->kind()) {
5342             case ProgramElement::Kind::kFunction: {
5343                 // Assign SpvIds to functions.
5344                 const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
5345                 const FunctionDeclaration& funcDecl = funcDef.declaration();
5346                 if (const Analysis::Specializations* specializations =
5347                             fSpecializationInfo.fSpecializationMap.find(&funcDecl)) {
5348                     for (int i = 0; i < specializations->size(); i++) {
5349                         fFunctionMap.set({&funcDecl, i}, this->nextId(nullptr));
5350                     }
5351                 } else {
5352                     fFunctionMap.set({&funcDecl, Analysis::kUnspecialized}, this->nextId(nullptr));
5353                 }
5354                 if (funcDecl.isMain()) {
5355                     main = &funcDecl;
5356                 }
5357                 break;
5358             }
5359             case ProgramElement::Kind::kGlobalVar: {
5360                 // Look for sampler variables and determine whether or not this program uses
5361                 // combined samplers or separate samplers. The layout backend will be marked as
5362                 // WebGPU for separate samplers, or Vulkan for combined samplers.
5363                 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
5364                 const Variable& var = *decl.varDeclaration().var();
5365                 if (var.type().isSampler()) {
5366                     if (var.layout().fFlags & LayoutFlag::kVulkan) {
5367                         combinedSamplerPos = decl.position();
5368                     }
5369                     if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
5370                         separateSamplerPos = decl.position();
5371                     }
5372                 }
5373                 break;
5374             }
5375             case ProgramElement::Kind::kModifiers: {
5376                 // If this is a compute program, collect the local-size values. Dimensions that are
5377                 // not present will be assigned a value of 1.
5378                 if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5379                     const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
5380                     if (modifiers.layout().fLocalSizeX >= 0) {
5381                         localSizeX = modifiers.layout().fLocalSizeX;
5382                     }
5383                     if (modifiers.layout().fLocalSizeY >= 0) {
5384                         localSizeY = modifiers.layout().fLocalSizeY;
5385                     }
5386                     if (modifiers.layout().fLocalSizeZ >= 0) {
5387                         localSizeZ = modifiers.layout().fLocalSizeZ;
5388                     }
5389                 }
5390                 break;
5391             }
5392             default:
5393                 break;
5394         }
5395     }
5396 
5397     // Make sure we have a main() function.
5398     if (!main) {
5399         fContext.fErrors->error(Position(), "program does not contain a main() function");
5400         return;
5401     }
5402     // Make sure our program's sampler usage is consistent.
5403     if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
5404         fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
5405         fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
5406         fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
5407         return;
5408     }
5409     fUseTextureSamplerPairs = separateSamplerPos.valid();
5410 
5411     // Emit interface blocks.
5412     std::set<SpvId> interfaceVars;
5413     for (const ProgramElement* e : program.elements()) {
5414         if (e->is<InterfaceBlock>()) {
5415             const InterfaceBlock& intf = e->as<InterfaceBlock>();
5416             SpvId id = this->writeInterfaceBlock(intf);
5417 
5418             if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
5419                 intf.var()->layout().fBuiltin == -1) {
5420                 interfaceVars.insert(id);
5421             }
5422         }
5423     }
5424     // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
5425     // explicitly declared in the output, whether or not the program explicitly references it.
5426     // However, if the program naturally declares it, we don't want to include it a second time;
5427     // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
5428     const VarDeclaration* missingClockwiseDecl = nullptr;
5429     if (fCaps.fMustDeclareFragmentFrontFacing) {
5430         if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
5431             missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
5432         }
5433     }
5434     // Emit global variable declarations.
5435     for (const ProgramElement* e : program.elements()) {
5436         if (e->is<GlobalVarDeclaration>()) {
5437             const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
5438             if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
5439                 return;
5440             }
5441             if (missingClockwiseDecl == &decl) {
5442                 // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
5443                 missingClockwiseDecl = nullptr;
5444             }
5445         }
5446     }
5447     // All the global variables have been declared. If sk_Clockwise was not naturally included in
5448     // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
5449     if (missingClockwiseDecl) {
5450         if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
5451             return;
5452         }
5453         missingClockwiseDecl = nullptr;
5454     }
5455     // Emit top-level uniforms into a dedicated uniform buffer.
5456     if (!fTopLevelUniforms.empty()) {
5457         this->writeUniformBuffer(get_top_level_symbol_table(*main));
5458     }
5459     // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
5460     // main() and stores the result into sk_FragColor.
5461     EntrypointAdapter adapter;
5462     if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
5463         adapter = this->writeEntrypointAdapter(*main);
5464         if (adapter.entrypointDecl) {
5465             fFunctionMap.set({adapter.entrypointDecl.get(), Analysis::kUnspecialized},
5466                              this->nextId(nullptr));
5467             this->writeFunction(*adapter.entrypointDef, body);
5468             main = adapter.entrypointDecl.get();
5469         }
5470     }
5471     // Emit all the functions.
5472     for (const ProgramElement* e : program.elements()) {
5473         if (e->is<FunctionDefinition>()) {
5474             this->writeFunction(e->as<FunctionDefinition>(), body);
5475         }
5476     }
5477     // Add global in/out variables to the list of interface variables.
5478     for (const auto& [var, spvId] : fVariableMap) {
5479         if (var->storage() == Variable::Storage::kGlobal &&
5480             (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
5481             interfaceVars.insert(spvId);
5482         }
5483     }
5484     this->writeCapabilities(out);
5485     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
5486     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
5487     this->writeOpCode(SpvOpEntryPoint,
5488                       (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
5489                       out);
5490     if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
5491         this->writeWord(SpvExecutionModelVertex, out);
5492     } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5493         this->writeWord(SpvExecutionModelFragment, out);
5494     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5495         this->writeWord(SpvExecutionModelGLCompute, out);
5496     } else {
5497         SK_ABORT("cannot write this kind of program to SPIR-V\n");
5498     }
5499     const Analysis::SpecializedFunctionKey mainKey{main, Analysis::kUnspecialized};
5500     SpvId entryPoint = fFunctionMap[mainKey];
5501     this->writeWord(entryPoint, out);
5502     this->writeString(main->name(), out);
5503     for (int var : interfaceVars) {
5504         this->writeWord(var, out);
5505     }
5506     if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5507         this->writeInstruction(SpvOpExecutionMode,
5508                                fFunctionMap[mainKey],
5509                                SpvExecutionModeOriginUpperLeft,
5510                                out);
5511     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5512         this->writeInstruction(SpvOpExecutionMode,
5513                                fFunctionMap[mainKey],
5514                                SpvExecutionModeLocalSize,
5515                                localSizeX, localSizeY, localSizeZ,
5516                                out);
5517     }
5518     for (const ProgramElement* e : program.elements()) {
5519         if (e->is<Extension>()) {
5520             this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
5521         }
5522     }
5523 
5524     write_stringstream(fNameBuffer, out);
5525     write_stringstream(fDecorationBuffer, out);
5526     write_stringstream(fConstantBuffer, out);
5527     write_stringstream(body, out);
5528 }
5529 
generateCode()5530 bool SPIRVCodeGenerator::generateCode() {
5531     SkASSERT(!fContext.fErrors->errorCount());
5532     this->writeWord(SpvMagicNumber, *fOut);
5533     this->writeWord(SpvVersion, *fOut);
5534     this->writeWord(SKSL_MAGIC, *fOut);
5535     StringStream buffer;
5536     this->writeInstructions(fProgram, buffer);
5537     this->writeWord(fIdCount, *fOut);
5538     this->writeWord(0, *fOut); // reserved, always zero
5539     write_stringstream(buffer, *fOut);
5540     return fContext.fErrors->errorCount() == 0;
5541 }
5542 
ToSPIRV(Program & program,const ShaderCaps * caps,OutputStream & out,ValidateSPIRVProc validateSPIRV)5543 bool ToSPIRV(Program& program,
5544              const ShaderCaps* caps,
5545              OutputStream& out,
5546              ValidateSPIRVProc validateSPIRV) {
5547     TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5548     SkASSERT(caps != nullptr);
5549 
5550     program.fContext->fErrors->setSource(*program.fSource);
5551     bool result;
5552     if (validateSPIRV) {
5553         StringStream buffer;
5554         SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5555         result = cg.generateCode();
5556 
5557         if (result && program.fConfig->fSettings.fValidateSPIRV) {
5558             std::string_view binary = buffer.str();
5559             result = validateSPIRV(*program.fContext->fErrors, binary);
5560             out.write(binary.data(), binary.size());
5561         }
5562     } else {
5563         SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5564         result = cg.generateCode();
5565     }
5566     program.fContext->fErrors->setSource(std::string_view());
5567 
5568     return result;
5569 }
5570 
ToSPIRV(Program & program,const ShaderCaps * caps,std::string * out,ValidateSPIRVProc validateSPIRV)5571 bool ToSPIRV(Program& program,
5572              const ShaderCaps* caps,
5573              std::string* out,
5574              ValidateSPIRVProc validateSPIRV) {
5575     StringStream buffer;
5576     if (!ToSPIRV(program, caps, buffer, validateSPIRV)) {
5577         return false;
5578     }
5579     *out = buffer.str();
5580     return true;
5581 }
5582 
5583 }  // namespace SkSL
5584