1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/analysis/SkSLSpecialization.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
46 #include "src/sksl/ir/SkSLConstructorSplat.h"
47 #include "src/sksl/ir/SkSLDoStatement.h"
48 #include "src/sksl/ir/SkSLExpression.h"
49 #include "src/sksl/ir/SkSLExpressionStatement.h"
50 #include "src/sksl/ir/SkSLExtension.h"
51 #include "src/sksl/ir/SkSLFieldAccess.h"
52 #include "src/sksl/ir/SkSLFieldSymbol.h"
53 #include "src/sksl/ir/SkSLForStatement.h"
54 #include "src/sksl/ir/SkSLFunctionCall.h"
55 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
56 #include "src/sksl/ir/SkSLFunctionDefinition.h"
57 #include "src/sksl/ir/SkSLIRNode.h"
58 #include "src/sksl/ir/SkSLIfStatement.h"
59 #include "src/sksl/ir/SkSLIndexExpression.h"
60 #include "src/sksl/ir/SkSLInterfaceBlock.h"
61 #include "src/sksl/ir/SkSLLayout.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLModifierFlags.h"
64 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
65 #include "src/sksl/ir/SkSLPoison.h"
66 #include "src/sksl/ir/SkSLPostfixExpression.h"
67 #include "src/sksl/ir/SkSLPrefixExpression.h"
68 #include "src/sksl/ir/SkSLProgram.h"
69 #include "src/sksl/ir/SkSLProgramElement.h"
70 #include "src/sksl/ir/SkSLReturnStatement.h"
71 #include "src/sksl/ir/SkSLSetting.h"
72 #include "src/sksl/ir/SkSLStatement.h"
73 #include "src/sksl/ir/SkSLSwitchCase.h"
74 #include "src/sksl/ir/SkSLSwitchStatement.h"
75 #include "src/sksl/ir/SkSLSwizzle.h"
76 #include "src/sksl/ir/SkSLSymbol.h"
77 #include "src/sksl/ir/SkSLSymbolTable.h"
78 #include "src/sksl/ir/SkSLTernaryExpression.h"
79 #include "src/sksl/ir/SkSLType.h"
80 #include "src/sksl/ir/SkSLVarDeclarations.h"
81 #include "src/sksl/ir/SkSLVariable.h"
82 #include "src/sksl/ir/SkSLVariableReference.h"
83 #include "src/sksl/spirv.h"
84 #include "src/sksl/transform/SkSLTransform.h"
85 #include "src/utils/SkBitSet.h"
86
87 #include <algorithm>
88 #include <cstdint>
89 #include <cstring>
90 #include <ctype.h>
91 #include <functional>
92 #include <memory>
93 #include <set>
94 #include <string>
95 #include <string_view>
96 #include <tuple>
97 #include <utility>
98 #include <vector>
99
100 using namespace skia_private;
101
// Highest SpvCapability value this generator accounts for.
// NOTE(review): presumably used as the upper bound when walking the capability bits stored in
// SPIRVCodeGenerator::fCapabilities (a uint64_t) — confirm at the use site. Kept as a macro
// because its use sites are not visible here.
#define kLast_Capability SpvCapabilityMultiViewport

// Synthetic "builtin" ids for device-space fragment coordinates and device-space clockwise
// (front-facing) state. NOTE(review): the large negative values look chosen so they cannot
// collide with real builtin ids — confirm against the code that consumes them.
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN = -1001;
// A default-constructed SkSL::Layout. Presumably passed wherever a `const Layout&` is required
// but no explicit layout applies — TODO confirm at use sites.
static constexpr SkSL::Layout kDefaultTypeLayout;
107
108 namespace SkSL {
109
enum class ProgramKind : int8_t;

// The storage classes the SPIR-V generator can assign to variables and pointers. Each value is
// translated to its SpvStorageClass counterpart by get_storage_class_spv_id(). The underlying
// type is spelled out explicitly; it is the same `int` an enum class gets by default.
enum class StorageClass : int {
    kUniformConstant,
    kInput,
    kUniform,
    kStorageBuffer,
    kOutput,
    kWorkgroup,
    kCrossWorkgroup,
    kPrivate,
    kFunction,
    kGeneric,
    kPushConstant,
    kAtomicCounter,
    kImage,
};
127
get_storage_class_spv_id(StorageClass storageClass)128 static SpvStorageClass get_storage_class_spv_id(StorageClass storageClass) {
129 switch (storageClass) {
130 case StorageClass::kUniformConstant: return SpvStorageClassUniformConstant;
131 case StorageClass::kInput: return SpvStorageClassInput;
132 case StorageClass::kUniform: return SpvStorageClassUniform;
133 // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
134 // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
135 // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
136 // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
137 // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
138 // would benefit from SPV_KHR_variable_pointers capabilities).
139 case StorageClass::kStorageBuffer: return SpvStorageClassUniform;
140 case StorageClass::kOutput: return SpvStorageClassOutput;
141 case StorageClass::kWorkgroup: return SpvStorageClassWorkgroup;
142 case StorageClass::kCrossWorkgroup: return SpvStorageClassCrossWorkgroup;
143 case StorageClass::kPrivate: return SpvStorageClassPrivate;
144 case StorageClass::kFunction: return SpvStorageClassFunction;
145 case StorageClass::kGeneric: return SpvStorageClassGeneric;
146 case StorageClass::kPushConstant: return SpvStorageClassPushConstant;
147 case StorageClass::kAtomicCounter: return SpvStorageClassAtomicCounter;
148 case StorageClass::kImage: return SpvStorageClassImage;
149 }
150
151 SkUNREACHABLE;
152 }
153
154 class SPIRVCodeGenerator : public CodeGenerator {
155 public:
156 // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
157 static constexpr SpvId NA = (SpvId)-1;
158
159 class LValue {
160 public:
~LValue()161 virtual ~LValue() {}
162
163 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
164 // by a pointer (e.g. vector swizzles), returns NA.
getPointer()165 virtual SpvId getPointer() { return NA; }
166
167 // Returns true if a valid pointer returned by getPointer represents a memory object
168 // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
169 // getPointer() returns NA.
isMemoryObjectPointer() const170 virtual bool isMemoryObjectPointer() const { return true; }
171
172 // Applies a swizzle to the components of the LValue, if possible. This is used to create
173 // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)174 virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
175 return false;
176 }
177
178 // Returns the storage class of the lvalue.
179 virtual StorageClass storageClass() const = 0;
180
181 virtual SpvId load(OutputStream& out) = 0;
182
183 virtual void store(SpvId value, OutputStream& out) = 0;
184 };
185
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)186 SPIRVCodeGenerator(const Context* context,
187 const ShaderCaps* caps,
188 const Program* program,
189 OutputStream* out)
190 : CodeGenerator(context, caps, program, out) {}
191
192 bool generateCode() override;
193
194 private:
195 enum IntrinsicOpcodeKind {
196 kGLSL_STD_450_IntrinsicOpcodeKind,
197 kSPIRV_IntrinsicOpcodeKind,
198 kSpecial_IntrinsicOpcodeKind,
199 kInvalid_IntrinsicOpcodeKind,
200 };
201
202 enum SpecialIntrinsic {
203 kAtan_SpecialIntrinsic,
204 kClamp_SpecialIntrinsic,
205 kMatrixCompMult_SpecialIntrinsic,
206 kMax_SpecialIntrinsic,
207 kMin_SpecialIntrinsic,
208 kMix_SpecialIntrinsic,
209 kMod_SpecialIntrinsic,
210 kDFdy_SpecialIntrinsic,
211 kSaturate_SpecialIntrinsic,
212 kSampledImage_SpecialIntrinsic,
213 kSmoothStep_SpecialIntrinsic,
214 kStep_SpecialIntrinsic,
215 kSubpassLoad_SpecialIntrinsic,
216 kTexture_SpecialIntrinsic,
217 kTextureGrad_SpecialIntrinsic,
218 kTextureLod_SpecialIntrinsic,
219 kTextureRead_SpecialIntrinsic,
220 kTextureWrite_SpecialIntrinsic,
221 kTextureWidth_SpecialIntrinsic,
222 kTextureHeight_SpecialIntrinsic,
223 kAtomicAdd_SpecialIntrinsic,
224 kAtomicLoad_SpecialIntrinsic,
225 kAtomicStore_SpecialIntrinsic,
226 kStorageBarrier_SpecialIntrinsic,
227 kWorkgroupBarrier_SpecialIntrinsic,
228 };
229
230 enum class Precision {
231 kDefault,
232 kRelaxed,
233 };
234
235 struct TempVar {
236 SpvId spvId;
237 const Type* type;
238 std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
239 };
240
241 /**
242 * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
243 * appropriate, or null to never add one.
244 */
245 SpvId nextId(const Type* type);
246
247 SpvId nextId(Precision precision);
248
249 SpvId getType(const Type& type);
250
251 SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
252
253 SpvId getFunctionType(const FunctionDeclaration& function);
254
255 SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
256
257 SpvId getPointerType(const Type& type, StorageClass storageClass);
258
259 SpvId getPointerType(const Type& type,
260 const Layout& typeLayout,
261 const MemoryLayout& memoryLayout,
262 StorageClass storageClass);
263
264 StorageClass getStorageClass(const Expression& expr);
265
266 TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
267
268 void writeLayout(const Layout& layout, SpvId target, Position pos);
269
270 void writeFieldLayout(const Layout& layout, SpvId target, int member);
271
272 SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
273
274 void writeProgramElement(const ProgramElement& pe, OutputStream& out);
275
276 SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
277
278 void writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
279
280 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
281
282 void writeFunction(const FunctionDefinition& f, OutputStream& out);
283
284 // Writes the function with the defined specializationIndex, if the index is -1, then it is
285 // assumed that the function has no specializations.
286 void writeFunctionInstantiation(const FunctionDefinition& f,
287 Analysis::SpecializationIndex specializationIndex,
288 const Analysis::SpecializedParameters* specializedParams,
289 OutputStream& out);
290
291 bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
292
293 SpvId writeGlobalVar(ProgramKind kind, StorageClass, const Variable& v);
294
295 void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
296
297 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
298
299 int findUniformFieldIndex(const Variable& var) const;
300
301 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
302
303 SpvId writeExpression(const Expression& expr, OutputStream& out);
304
305 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
306
307 void writeFunctionCallArgument(TArray<SpvId>& argumentList,
308 const FunctionCall& call,
309 int argIndex,
310 std::vector<TempVar>* tempVars,
311 const SkBitSet* specializedParams,
312 OutputStream& out);
313
314 void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
315
316 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
317
318
319 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
320 SpvId signedInst, SpvId unsignedInst,
321 const TArray<SpvId>& args, OutputStream& out);
322
323 /**
324 * Promotes an expression to a vector. If the expression is already a vector with vectorSize
325 * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
326 * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
327 * expression is already a vector and it does not have vectorSize columns.
328 */
329 SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
330
331 /**
332 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
333 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
334 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
335 * vec2, vec3).
336 */
337 TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
338
339 /**
340 * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
341 * returns the SpvId of the new value.
342 */
343 SpvId splat(const Type& type, SpvId id, OutputStream& out);
344
345 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
346 SpvId writeAtomicIntrinsic(const FunctionCall& c,
347 SpecialIntrinsic kind,
348 SpvId resultId,
349 OutputStream& out);
350
351 SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
352 OutputStream& out);
353
354 SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
355 OutputStream& out);
356
357 SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
358 OutputStream& out);
359
360 SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
361 OutputStream& out);
362
363 SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
364 OutputStream& out);
365
366 /**
367 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
368 * source matrix are filled with zero; entries which do not exist in the destination matrix are
369 * ignored.
370 */
371 SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
372
373 void addColumnEntry(const Type& columnType,
374 TArray<SpvId>* currentColumn,
375 TArray<SpvId>* columnIds,
376 int rows,
377 SpvId entry,
378 OutputStream& out);
379
380 SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
381
382 SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
383
384 SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
385
386 SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
387
388 SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
389
390 SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
391
392 SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
393
394 SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
395
396 SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
397
398 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
399
400 SpvId writeSwizzle(const Expression& baseExpr,
401 const ComponentArray& components,
402 OutputStream& out);
403
404 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
405
406 /**
407 * Folds the potentially-vector result of a logical operation down to a single bool. If
408 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
409 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
410 * returns the original id value.
411 */
412 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
413
414 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
415 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
416 SpvOp_ mergeOperator, OutputStream& out);
417
418 SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
419 OutputStream& out);
420
421 SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
422 OutputStream& out);
423
424 // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
425 // comparisons into an overall comparison result.
426 // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
427 // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
428 SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
429
430 // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
431 // multiplication into a sum of vector-scalar products.
432 SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
433 SpvId lhs,
434 const Type& rightType,
435 SpvId rhs,
436 const Type& resultType,
437 OutputStream& out);
438
439 SpvId writeComponentwiseMatrixUnary(const Type& operandType,
440 SpvId operand,
441 SpvOp_ op,
442 OutputStream& out);
443
444 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
445 SpvOp_ op, OutputStream& out);
446
447 SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
448 SpvId lhs, SpvId rhs,
449 SpvOp_ ifFloat, SpvOp_ ifInt,
450 SpvOp_ ifUInt, SpvOp_ ifBool,
451 OutputStream& out);
452
453 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
454 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
455 SpvOp_ ifBool, OutputStream& out);
456
457 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
458 SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
459 SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
460
461 SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
462
463 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
464 const Type& rightType, SpvId rhs, const Type& resultType,
465 OutputStream& out);
466
467 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
468
469 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
470
471 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
472
473 SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
474
475 SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
476
477 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
478
479 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
480
481 SpvId writeLiteral(const Literal& f);
482
483 SpvId writeLiteral(double value, const Type& type);
484
485 void writeStatement(const Statement& s, OutputStream& out);
486
487 void writeBlock(const Block& b, OutputStream& out);
488
489 void writeIfStatement(const IfStatement& stmt, OutputStream& out);
490
491 void writeForStatement(const ForStatement& f, OutputStream& out);
492
493 void writeDoStatement(const DoStatement& d, OutputStream& out);
494
495 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
496
497 void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
498
499 void writeCapabilities(OutputStream& out);
500
501 void writeInstructions(const Program& program, OutputStream& out);
502
503 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
504
505 void writeWord(int32_t word, OutputStream& out);
506
507 void writeString(std::string_view s, OutputStream& out);
508
509 void writeInstruction(SpvOp_ opCode, OutputStream& out);
510
511 void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
512
513 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
514
515 void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
516 OutputStream& out);
517
518 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
519 OutputStream& out);
520
521 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
522
523 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
524 OutputStream& out);
525
526 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
527 OutputStream& out);
528
529 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
530 int32_t word5, OutputStream& out);
531
532 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
533 int32_t word5, int32_t word6, OutputStream& out);
534
535 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
536 int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
537
538 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
539 int32_t word5, int32_t word6, int32_t word7, int32_t word8,
540 OutputStream& out);
541
542 // This form of writeInstruction can deduplicate redundant ops.
543 struct Word;
544 // 8 Words is enough for nearly all instructions (except variable-length instructions like
545 // OpAccessChain or OpConstantComposite).
546 using Words = STArray<8, Word, true>;
547 SpvId writeInstruction(SpvOp_ opCode, const TArray<Word, true>& words, OutputStream& out);
548
549 struct Instruction {
550 SpvId fOp;
551 int32_t fResultKind;
552 STArray<8, int32_t> fWords;
553
554 bool operator==(const Instruction& that) const;
555 struct Hash;
556 };
557
558 static Instruction BuildInstructionKey(SpvOp_ opCode, const TArray<Word, true>& words);
559
560 // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
561 SpvId writeOpConstantTrue(const Type& type);
562 SpvId writeOpConstantFalse(const Type& type);
563 SpvId writeOpConstant(const Type& type, int32_t valueBits);
564 SpvId writeOpConstantComposite(const Type& type, const TArray<SpvId>& values);
565 SpvId writeOpCompositeConstruct(const Type& type, const TArray<SpvId>&, OutputStream& out);
566 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
567 SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
568 OutputStream& out);
569 SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
570 void writeOpStore(StorageClass storageClass, SpvId pointer, SpvId value, OutputStream& out);
571
572 // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
573 bool toConstants(SpvId value, TArray<SpvId>* constants);
574 bool toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants);
575
576 // Extracts the requested component SpvId from a composite instruction, if it can be done.
577 Instruction* resultTypeForInstruction(const Instruction& instr);
578 int numComponentsForVecInstruction(const Instruction& instr);
579 SpvId toComponent(SpvId id, int component);
580
581 struct ConditionalOpCounts {
582 int numReachableOps;
583 int numStoreOps;
584 };
585 ConditionalOpCounts getConditionalOpCounts();
586 void pruneConditionalOps(ConditionalOpCounts ops);
587
588 enum StraightLineLabelType {
589 // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
590 // happens at the start of a function, or when we find unreachable code.
591 kBranchlessBlock,
592
593 // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
594 // associated branch. Example usage:
595 // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
596 // use an OpBranch to explicitly jump to the next block, even when they are adjacent in
597 // the code.
598 // - The block immediately following an OpBranchConditional or OpSwitch.
599 kBranchIsOnPreviousLine,
600 };
601
602 enum BranchingLabelType {
603 // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
604 // ops that are above the label in the code--i.e., the branch skips forward in the code.
605 kBranchIsAbove,
606
607 // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
608 // ops below the label in the code--i.e., the branch jumps backward in the code.
609 kBranchIsBelow,
610
611 // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
612 kBranchesOnBothSides,
613 };
614 void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
615 void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
616 OutputStream& out);
617
618 MemoryLayout memoryLayoutForStorageClass(StorageClass storageClass);
619 MemoryLayout memoryLayoutForVariable(const Variable&) const;
620
621 struct EntrypointAdapter {
622 std::unique_ptr<FunctionDefinition> entrypointDef;
623 std::unique_ptr<FunctionDeclaration> entrypointDecl;
624 };
625
626 EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
627
628 struct UniformBuffer {
629 std::unique_ptr<InterfaceBlock> fInterfaceBlock;
630 std::unique_ptr<Variable> fInnerVariable;
631 std::unique_ptr<Type> fStruct;
632 };
633
634 void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
635
636 void addRTFlipUniform(Position pos);
637
638 std::unique_ptr<Expression> identifier(std::string_view name);
639
640 std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
641 const Variable& combinedSampler);
642
643 const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
644
645 uint64_t fCapabilities = 0;
646 SpvId fIdCount = 1;
647 SpvId fGLSLExtendedInstructions;
648 struct Intrinsic {
649 IntrinsicOpcodeKind opKind;
650 int32_t floatOp;
651 int32_t signedOp;
652 int32_t unsignedOp;
653 int32_t boolOp;
654 };
655 Intrinsic getIntrinsic(IntrinsicKind) const;
656
657 THashMap<Analysis::SpecializedFunctionKey, SpvId, Analysis::SpecializedFunctionKey::Hash>
658 fFunctionMap;
659
660 Analysis::SpecializationInfo fSpecializationInfo;
661 Analysis::SpecializationIndex fActiveSpecializationIndex = Analysis::kUnspecialized;
662 const Analysis::SpecializedParameters* fActiveSpecialization = nullptr;
663
664 THashMap<const Variable*, SpvId> fVariableMap;
665 THashMap<const Type*, SpvId> fStructMap;
666 StringStream fGlobalInitializersBuffer;
667 StringStream fConstantBuffer;
668 StringStream fVariableBuffer;
669 StringStream fNameBuffer;
670 StringStream fDecorationBuffer;
671
672 // Mapping from combined sampler declarations to synthesized texture/sampler variables.
673 // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
674 bool fUseTextureSamplerPairs = false;
675 struct SynthesizedTextureSamplerPair {
676 // The names of the synthesized variables. The Variable objects themselves store string
677 // views referencing these strings. It is important for the std::string instances to have a
678 // fixed memory location after the string views get created, which is why
679 // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
680 std::string fTextureName;
681 std::string fSamplerName;
682 std::unique_ptr<Variable> fTexture;
683 std::unique_ptr<Variable> fSampler;
684 };
685 THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
686 fSynthesizedSamplerMap;
687
688 // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
689 // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
690 // and simplify code we've already emitted (by taking a SpvId from an Instruction and following
691 // it back to its source).
692
693 // A map of instruction -> SpvId:
694 THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
695 // A map of SpvId -> instruction:
696 THashMap<SpvId, Instruction> fSpvIdCache;
697 // A map of SpvId -> value SpvId:
698 THashMap<SpvId, SpvId> fStoreCache;
699
700 // "Reachable" ops are instructions which can safely be accessed from the current block.
701 // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
702 // reuse that computation on following lines. However, if that Add operation occurred inside an
703 // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
704 // depending on the if condition, we may or may not have actually done that computation). The
705 // same logic applies to other control-flow blocks as well. Once an instruction becomes
706 // unreachable, we remove it from both op-caches.
707 TArray<SpvId> fReachableOps;
708
709 // The "store-ops" list contains a running list of all the pointers in the store cache. If a
710 // store occurs inside of a conditional block, once that block exits, we no longer know what is
711 // stored in that particular SpvId. At that point, we must remove any associated entry from the
712 // store cache.
713 TArray<SpvId> fStoreOps;
714
715 // label of the current block, or 0 if we are not in a block
716 SpvId fCurrentBlock = 0;
717 TArray<SpvId> fBreakTarget;
718 TArray<SpvId> fContinueTarget;
719 bool fWroteRTFlip = false;
720 // holds variables synthesized during output, for lifetime purposes
721 SymbolTable fSynthetics{/*builtin=*/true};
722 // Holds a list of uniforms that were declared as globals at the top-level instead of in an
723 // interface block.
724 UniformBuffer fUniformBuffer;
725 std::vector<const VarDeclaration*> fTopLevelUniforms;
726 THashMap<const Variable*, int> fTopLevelUniformMap; // <var, UniformBuffer field index>
727 SpvId fUniformBufferId = NA;
728
729 friend class PointerLValue;
730 friend class SwizzleLValue;
731 };
732
733 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const734 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
735 return fOp == that.fOp &&
736 fResultKind == that.fResultKind &&
737 fWords == that.fWords;
738 }
739
740 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash741 uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
742 uint32_t hash = key.fResultKind;
743 hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
744 hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
745 return hash;
746 }
747 };
748
749 // This class is used to pass values and result placeholder slots to writeInstruction.
750 struct SPIRVCodeGenerator::Word {
751 enum Kind {
752 kNone, // intended for use as a sentinel, not part of any Instruction
753 kSpvId,
754 kNumber,
755 kDefaultPrecisionResult,
756 kRelaxedPrecisionResult,
757 kUniqueResult,
758 kKeyedResult,
759 };
760
WordSkSL::SPIRVCodeGenerator::Word761 Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word762 Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
763
NumberSkSL::SPIRVCodeGenerator::Word764 static Word Number(int32_t val) {
765 return Word{val, Kind::kNumber};
766 }
767
ResultSkSL::SPIRVCodeGenerator::Word768 static Word Result(const Type& type) {
769 return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
770 }
771
RelaxedResultSkSL::SPIRVCodeGenerator::Word772 static Word RelaxedResult() {
773 return Word{(int32_t)NA, kRelaxedPrecisionResult};
774 }
775
UniqueResultSkSL::SPIRVCodeGenerator::Word776 static Word UniqueResult() {
777 return Word{(int32_t)NA, kUniqueResult};
778 }
779
ResultSkSL::SPIRVCodeGenerator::Word780 static Word Result() {
781 return Word{(int32_t)NA, kDefaultPrecisionResult};
782 }
783
784 // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
785 // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
786 // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word787 static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
788
isResultSkSL::SPIRVCodeGenerator::Word789 bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
790
791 int32_t fValue;
792 Kind fKind;
793 };
794
// Generator magic number written into the SPIR-V header. Skia's registered tool ID is 31 and
// occupies the top 16 bits; the low 16 bits are free to version the SkSL generator if desired.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 31 << 16;  // == 0x001F0000
799
getIntrinsic(IntrinsicKind ik) const800 SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {
801
802 #define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
803 GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
804 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
805 GLSLstd450 ## ifFloat, \
806 GLSLstd450 ## ifInt, \
807 GLSLstd450 ## ifUInt, \
808 SpvOpUndef}
809 #define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
810 SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
811 #define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
812 SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
813 #define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
814 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
815 #define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
816 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
817 k ## x ## _SpecialIntrinsic}
818
819 switch (ik) {
820 case k_round_IntrinsicKind: return ALL_GLSL(Round);
821 case k_roundEven_IntrinsicKind: return ALL_GLSL(RoundEven);
822 case k_trunc_IntrinsicKind: return ALL_GLSL(Trunc);
823 case k_abs_IntrinsicKind: return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
824 case k_sign_IntrinsicKind: return BY_TYPE_GLSL(FSign, SSign, SSign);
825 case k_floor_IntrinsicKind: return ALL_GLSL(Floor);
826 case k_ceil_IntrinsicKind: return ALL_GLSL(Ceil);
827 case k_fract_IntrinsicKind: return ALL_GLSL(Fract);
828 case k_radians_IntrinsicKind: return ALL_GLSL(Radians);
829 case k_degrees_IntrinsicKind: return ALL_GLSL(Degrees);
830 case k_sin_IntrinsicKind: return ALL_GLSL(Sin);
831 case k_cos_IntrinsicKind: return ALL_GLSL(Cos);
832 case k_tan_IntrinsicKind: return ALL_GLSL(Tan);
833 case k_asin_IntrinsicKind: return ALL_GLSL(Asin);
834 case k_acos_IntrinsicKind: return ALL_GLSL(Acos);
835 case k_atan_IntrinsicKind: return SPECIAL(Atan);
836 case k_sinh_IntrinsicKind: return ALL_GLSL(Sinh);
837 case k_cosh_IntrinsicKind: return ALL_GLSL(Cosh);
838 case k_tanh_IntrinsicKind: return ALL_GLSL(Tanh);
839 case k_asinh_IntrinsicKind: return ALL_GLSL(Asinh);
840 case k_acosh_IntrinsicKind: return ALL_GLSL(Acosh);
841 case k_atanh_IntrinsicKind: return ALL_GLSL(Atanh);
842 case k_pow_IntrinsicKind: return ALL_GLSL(Pow);
843 case k_exp_IntrinsicKind: return ALL_GLSL(Exp);
844 case k_log_IntrinsicKind: return ALL_GLSL(Log);
845 case k_exp2_IntrinsicKind: return ALL_GLSL(Exp2);
846 case k_log2_IntrinsicKind: return ALL_GLSL(Log2);
847 case k_sqrt_IntrinsicKind: return ALL_GLSL(Sqrt);
848 case k_inverse_IntrinsicKind: return ALL_GLSL(MatrixInverse);
849 case k_outerProduct_IntrinsicKind: return ALL_SPIRV(OuterProduct);
850 case k_transpose_IntrinsicKind: return ALL_SPIRV(Transpose);
851 case k_isinf_IntrinsicKind: return ALL_SPIRV(IsInf);
852 case k_isnan_IntrinsicKind: return ALL_SPIRV(IsNan);
853 case k_inversesqrt_IntrinsicKind: return ALL_GLSL(InverseSqrt);
854 case k_determinant_IntrinsicKind: return ALL_GLSL(Determinant);
855 case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
856 case k_matrixInverse_IntrinsicKind: return ALL_GLSL(MatrixInverse);
857 case k_mod_IntrinsicKind: return SPECIAL(Mod);
858 case k_modf_IntrinsicKind: return ALL_GLSL(Modf);
859 case k_min_IntrinsicKind: return SPECIAL(Min);
860 case k_max_IntrinsicKind: return SPECIAL(Max);
861 case k_clamp_IntrinsicKind: return SPECIAL(Clamp);
862 case k_saturate_IntrinsicKind: return SPECIAL(Saturate);
863 case k_dot_IntrinsicKind: return FLOAT_SPIRV(Dot);
864 case k_mix_IntrinsicKind: return SPECIAL(Mix);
865 case k_step_IntrinsicKind: return SPECIAL(Step);
866 case k_smoothstep_IntrinsicKind: return SPECIAL(SmoothStep);
867 case k_fma_IntrinsicKind: return ALL_GLSL(Fma);
868 case k_frexp_IntrinsicKind: return ALL_GLSL(Frexp);
869 case k_ldexp_IntrinsicKind: return ALL_GLSL(Ldexp);
870
871 #define PACK(type) case k_pack##type##_IntrinsicKind: return ALL_GLSL(Pack##type); \
872 case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
873 PACK(Snorm4x8);
874 PACK(Unorm4x8);
875 PACK(Snorm2x16);
876 PACK(Unorm2x16);
877 PACK(Half2x16);
878 #undef PACK
879
880 case k_length_IntrinsicKind: return ALL_GLSL(Length);
881 case k_distance_IntrinsicKind: return ALL_GLSL(Distance);
882 case k_cross_IntrinsicKind: return ALL_GLSL(Cross);
883 case k_normalize_IntrinsicKind: return ALL_GLSL(Normalize);
884 case k_faceforward_IntrinsicKind: return ALL_GLSL(FaceForward);
885 case k_reflect_IntrinsicKind: return ALL_GLSL(Reflect);
886 case k_refract_IntrinsicKind: return ALL_GLSL(Refract);
887 case k_bitCount_IntrinsicKind: return ALL_SPIRV(BitCount);
888 case k_findLSB_IntrinsicKind: return ALL_GLSL(FindILsb);
889 case k_findMSB_IntrinsicKind: return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
890 case k_dFdx_IntrinsicKind: return FLOAT_SPIRV(DPdx);
891 case k_dFdy_IntrinsicKind: return SPECIAL(DFdy);
892 case k_fwidth_IntrinsicKind: return FLOAT_SPIRV(Fwidth);
893
894 case k_sample_IntrinsicKind: return SPECIAL(Texture);
895 case k_sampleGrad_IntrinsicKind: return SPECIAL(TextureGrad);
896 case k_sampleLod_IntrinsicKind: return SPECIAL(TextureLod);
897 case k_subpassLoad_IntrinsicKind: return SPECIAL(SubpassLoad);
898
899 case k_textureRead_IntrinsicKind: return SPECIAL(TextureRead);
900 case k_textureWrite_IntrinsicKind: return SPECIAL(TextureWrite);
901 case k_textureWidth_IntrinsicKind: return SPECIAL(TextureWidth);
902 case k_textureHeight_IntrinsicKind: return SPECIAL(TextureHeight);
903
904 case k_floatBitsToInt_IntrinsicKind: return ALL_SPIRV(Bitcast);
905 case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
906 case k_intBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);
907 case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);
908
909 case k_any_IntrinsicKind: return BOOL_SPIRV(Any);
910 case k_all_IntrinsicKind: return BOOL_SPIRV(All);
911 case k_not_IntrinsicKind: return BOOL_SPIRV(LogicalNot);
912
913 case k_equal_IntrinsicKind:
914 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
915 SpvOpFOrdEqual,
916 SpvOpIEqual,
917 SpvOpIEqual,
918 SpvOpLogicalEqual};
919 case k_notEqual_IntrinsicKind:
920 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
921 SpvOpFUnordNotEqual,
922 SpvOpINotEqual,
923 SpvOpINotEqual,
924 SpvOpLogicalNotEqual};
925 case k_lessThan_IntrinsicKind:
926 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
927 SpvOpFOrdLessThan,
928 SpvOpSLessThan,
929 SpvOpULessThan,
930 SpvOpUndef};
931 case k_lessThanEqual_IntrinsicKind:
932 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
933 SpvOpFOrdLessThanEqual,
934 SpvOpSLessThanEqual,
935 SpvOpULessThanEqual,
936 SpvOpUndef};
937 case k_greaterThan_IntrinsicKind:
938 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
939 SpvOpFOrdGreaterThan,
940 SpvOpSGreaterThan,
941 SpvOpUGreaterThan,
942 SpvOpUndef};
943 case k_greaterThanEqual_IntrinsicKind:
944 return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
945 SpvOpFOrdGreaterThanEqual,
946 SpvOpSGreaterThanEqual,
947 SpvOpUGreaterThanEqual,
948 SpvOpUndef};
949
950 case k_atomicAdd_IntrinsicKind: return SPECIAL(AtomicAdd);
951 case k_atomicLoad_IntrinsicKind: return SPECIAL(AtomicLoad);
952 case k_atomicStore_IntrinsicKind: return SPECIAL(AtomicStore);
953
954 case k_storageBarrier_IntrinsicKind: return SPECIAL(StorageBarrier);
955 case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);
956
957 default:
958 return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
959 }
960 }
961
writeWord(int32_t word,OutputStream & out)962 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
963 out.write((const char*) &word, sizeof(word));
964 }
965
is_float(const Type & type)966 static bool is_float(const Type& type) {
967 return (type.isScalar() || type.isVector() || type.isMatrix()) &&
968 type.componentType().isFloat();
969 }
970
is_signed(const Type & type)971 static bool is_signed(const Type& type) {
972 return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
973 }
974
is_unsigned(const Type & type)975 static bool is_unsigned(const Type& type) {
976 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
977 }
978
is_bool(const Type & type)979 static bool is_bool(const Type& type) {
980 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
981 }
982
983 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)984 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
985 if (is_float(type)) {
986 return ifFloat;
987 }
988 if (is_signed(type)) {
989 return ifInt;
990 }
991 if (is_unsigned(type)) {
992 return ifUInt;
993 }
994 if (is_bool(type)) {
995 return ifBool;
996 }
997 SkDEBUGFAIL("unrecognized type");
998 return ifFloat;
999 }
1000
is_out(ModifierFlags f)1001 static bool is_out(ModifierFlags f) {
1002 return SkToBool(f & ModifierFlag::kOut);
1003 }
1004
is_in(ModifierFlags f)1005 static bool is_in(ModifierFlags f) {
1006 if (f & ModifierFlag::kIn) {
1007 return true; // `in` and `inout` both count
1008 }
1009 // If neither in/out flag is set, the type is implicitly `in`.
1010 return !SkToBool(f & ModifierFlag::kOut);
1011 }
1012
is_control_flow_op(SpvOp_ op)1013 static bool is_control_flow_op(SpvOp_ op) {
1014 switch (op) {
1015 case SpvOpReturn:
1016 case SpvOpReturnValue:
1017 case SpvOpKill:
1018 case SpvOpSwitch:
1019 case SpvOpBranch:
1020 case SpvOpBranchConditional:
1021 return true;
1022 default:
1023 return false;
1024 }
1025 }
1026
is_globally_reachable_op(SpvOp_ op)1027 static bool is_globally_reachable_op(SpvOp_ op) {
1028 switch (op) {
1029 case SpvOpConstant:
1030 case SpvOpConstantTrue:
1031 case SpvOpConstantFalse:
1032 case SpvOpConstantComposite:
1033 case SpvOpTypeVoid:
1034 case SpvOpTypeInt:
1035 case SpvOpTypeFloat:
1036 case SpvOpTypeBool:
1037 case SpvOpTypeVector:
1038 case SpvOpTypeMatrix:
1039 case SpvOpTypeArray:
1040 case SpvOpTypePointer:
1041 case SpvOpTypeFunction:
1042 case SpvOpTypeRuntimeArray:
1043 case SpvOpTypeStruct:
1044 case SpvOpTypeImage:
1045 case SpvOpTypeSampledImage:
1046 case SpvOpTypeSampler:
1047 case SpvOpVariable:
1048 case SpvOpFunction:
1049 case SpvOpFunctionParameter:
1050 case SpvOpFunctionEnd:
1051 case SpvOpExecutionMode:
1052 case SpvOpMemoryModel:
1053 case SpvOpCapability:
1054 case SpvOpExtInstImport:
1055 case SpvOpEntryPoint:
1056 case SpvOpSource:
1057 case SpvOpSourceExtension:
1058 case SpvOpName:
1059 case SpvOpMemberName:
1060 case SpvOpDecorate:
1061 case SpvOpMemberDecorate:
1062 return true;
1063 default:
1064 return false;
1065 }
1066 }
1067
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1068 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1069 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1070 SkASSERT(opCode != SpvOpUndef);
1071 bool foundDeadCode = false;
1072 if (is_control_flow_op(opCode)) {
1073 // This instruction causes us to leave the current block.
1074 foundDeadCode = (fCurrentBlock == 0);
1075 fCurrentBlock = 0;
1076 } else if (!is_globally_reachable_op(opCode)) {
1077 foundDeadCode = (fCurrentBlock == 0);
1078 }
1079
1080 if (foundDeadCode) {
1081 // We just encountered dead code--an instruction that don't have an associated block.
1082 // Synthesize a label if this happens; this is necessary to satisfy the validator.
1083 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1084 }
1085
1086 this->writeWord((length << 16) | opCode, out);
1087 }
1088
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1089 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1090 // The straight-line label type is not important; in any case, no caches are invalidated.
1091 SkASSERT(!fCurrentBlock);
1092 fCurrentBlock = label;
1093 this->writeInstruction(SpvOpLabel, label, out);
1094 }
1095
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1096 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1097 ConditionalOpCounts ops, OutputStream& out) {
1098 switch (type) {
1099 case kBranchIsBelow:
1100 case kBranchesOnBothSides:
1101 // With a backward or bidirectional branch, we haven't seen the code between the label
1102 // and the branch yet, so any stored value is potentially suspect. Without scanning
1103 // ahead to check, the only safe option is to ditch the store cache entirely.
1104 fStoreCache.reset();
1105 [[fallthrough]];
1106
1107 case kBranchIsAbove:
1108 // With a forward branch, we can rely on stores that we had cached at the start of the
1109 // statement/expression, if they haven't been touched yet. Anything newer than that is
1110 // pruned.
1111 this->pruneConditionalOps(ops);
1112 break;
1113 }
1114
1115 // Emit the label.
1116 this->writeLabel(label, kBranchlessBlock, out);
1117 }
1118
writeInstruction(SpvOp_ opCode,OutputStream & out)1119 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1120 this->writeOpCode(opCode, 1, out);
1121 }
1122
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1123 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1124 this->writeOpCode(opCode, 2, out);
1125 this->writeWord(word1, out);
1126 }
1127
writeString(std::string_view s,OutputStream & out)1128 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1129 out.write(s.data(), s.length());
1130 switch (s.length() % 4) {
1131 case 1:
1132 out.write8(0);
1133 [[fallthrough]];
1134 case 2:
1135 out.write8(0);
1136 [[fallthrough]];
1137 case 3:
1138 out.write8(0);
1139 break;
1140 default:
1141 this->writeWord(0, out);
1142 break;
1143 }
1144 }
1145
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1146 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1147 OutputStream& out) {
1148 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1149 this->writeString(string, out);
1150 }
1151
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1152 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1153 OutputStream& out) {
1154 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1155 this->writeWord(word1, out);
1156 this->writeString(string, out);
1157 }
1158
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1159 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1160 std::string_view string, OutputStream& out) {
1161 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1162 this->writeWord(word1, out);
1163 this->writeWord(word2, out);
1164 this->writeString(string, out);
1165 }
1166
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1167 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1168 OutputStream& out) {
1169 this->writeOpCode(opCode, 3, out);
1170 this->writeWord(word1, out);
1171 this->writeWord(word2, out);
1172 }
1173
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1174 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1175 int32_t word3, OutputStream& out) {
1176 this->writeOpCode(opCode, 4, out);
1177 this->writeWord(word1, out);
1178 this->writeWord(word2, out);
1179 this->writeWord(word3, out);
1180 }
1181
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1182 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1183 int32_t word3, int32_t word4, OutputStream& out) {
1184 this->writeOpCode(opCode, 5, out);
1185 this->writeWord(word1, out);
1186 this->writeWord(word2, out);
1187 this->writeWord(word3, out);
1188 this->writeWord(word4, out);
1189 }
1190
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1191 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1192 int32_t word3, int32_t word4, int32_t word5,
1193 OutputStream& out) {
1194 this->writeOpCode(opCode, 6, out);
1195 this->writeWord(word1, out);
1196 this->writeWord(word2, out);
1197 this->writeWord(word3, out);
1198 this->writeWord(word4, out);
1199 this->writeWord(word5, out);
1200 }
1201
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1202 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1203 int32_t word3, int32_t word4, int32_t word5,
1204 int32_t word6, OutputStream& out) {
1205 this->writeOpCode(opCode, 7, out);
1206 this->writeWord(word1, out);
1207 this->writeWord(word2, out);
1208 this->writeWord(word3, out);
1209 this->writeWord(word4, out);
1210 this->writeWord(word5, out);
1211 this->writeWord(word6, out);
1212 }
1213
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1214 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1215 int32_t word3, int32_t word4, int32_t word5,
1216 int32_t word6, int32_t word7, OutputStream& out) {
1217 this->writeOpCode(opCode, 8, out);
1218 this->writeWord(word1, out);
1219 this->writeWord(word2, out);
1220 this->writeWord(word3, out);
1221 this->writeWord(word4, out);
1222 this->writeWord(word5, out);
1223 this->writeWord(word6, out);
1224 this->writeWord(word7, out);
1225 }
1226
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1227 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1228 int32_t word3, int32_t word4, int32_t word5,
1229 int32_t word6, int32_t word7, int32_t word8,
1230 OutputStream& out) {
1231 this->writeOpCode(opCode, 9, out);
1232 this->writeWord(word1, out);
1233 this->writeWord(word2, out);
1234 this->writeWord(word3, out);
1235 this->writeWord(word4, out);
1236 this->writeWord(word5, out);
1237 this->writeWord(word6, out);
1238 this->writeWord(word7, out);
1239 this->writeWord(word8, out);
1240 }
1241
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1242 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1243 const TArray<Word>& words) {
1244 // Assemble a cache key for this instruction.
1245 Instruction key;
1246 key.fOp = opCode;
1247 key.fWords.resize(words.size());
1248 key.fResultKind = Word::Kind::kNone;
1249
1250 for (int index = 0; index < words.size(); ++index) {
1251 const Word& word = words[index];
1252 key.fWords[index] = word.fValue;
1253 if (word.isResult()) {
1254 SkASSERT(key.fResultKind == Word::Kind::kNone);
1255 key.fResultKind = word.fKind;
1256 }
1257 }
1258
1259 return key;
1260 }
1261
writeInstruction(SpvOp_ opCode,const TArray<Word> & words,OutputStream & out)1262 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
1263 const TArray<Word>& words,
1264 OutputStream& out) {
1265 // writeOpLoad and writeOpStore have dedicated code.
1266 SkASSERT(opCode != SpvOpLoad);
1267 SkASSERT(opCode != SpvOpStore);
1268
1269 // If this instruction exists in our op cache, return the cached SpvId.
1270 Instruction key = BuildInstructionKey(opCode, words);
1271 if (SpvId* cachedOp = fOpCache.find(key)) {
1272 return *cachedOp;
1273 }
1274
1275 SpvId result = NA;
1276 Precision precision = Precision::kDefault;
1277
1278 switch (key.fResultKind) {
1279 case Word::Kind::kUniqueResult:
1280 // The instruction returns a SpvId, but we do not want deduplication.
1281 result = this->nextId(Precision::kDefault);
1282 fSpvIdCache.set(result, key);
1283 break;
1284
1285 case Word::Kind::kNone:
1286 // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
1287 fOpCache.set(key, result);
1288 break;
1289
1290 case Word::Kind::kRelaxedPrecisionResult:
1291 precision = Precision::kRelaxed;
1292 [[fallthrough]];
1293
1294 case Word::Kind::kKeyedResult:
1295 [[fallthrough]];
1296
1297 case Word::Kind::kDefaultPrecisionResult:
1298 // Consume a new SpvId.
1299 result = this->nextId(precision);
1300 fOpCache.set(key, result);
1301 fSpvIdCache.set(result, key);
1302
1303 // Globally-reachable ops are not subject to the whims of flow control.
1304 if (!is_globally_reachable_op(opCode)) {
1305 fReachableOps.push_back(result);
1306 }
1307 break;
1308
1309 default:
1310 SkDEBUGFAIL("unexpected result kind");
1311 break;
1312 }
1313
1314 // Write the requested instruction.
1315 this->writeOpCode(opCode, words.size() + 1, out);
1316 for (const Word& word : words) {
1317 if (word.isResult()) {
1318 SkASSERT(result != NA);
1319 this->writeWord(result, out);
1320 } else {
1321 this->writeWord(word.fValue, out);
1322 }
1323 }
1324
1325 // Return the result.
1326 return result;
1327 }
1328
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1329 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1330 Precision precision,
1331 SpvId pointer,
1332 OutputStream& out) {
1333 // Look for this pointer in our load-cache.
1334 if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1335 return *cachedOp;
1336 }
1337
1338 // Write the requested OpLoad instruction.
1339 SpvId result = this->nextId(precision);
1340 this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1341 return result;
1342 }
1343
writeOpStore(StorageClass storageClass,SpvId pointer,SpvId value,OutputStream & out)1344 void SPIRVCodeGenerator::writeOpStore(StorageClass storageClass,
1345 SpvId pointer,
1346 SpvId value,
1347 OutputStream& out) {
1348 // Write the uncached SpvOpStore directly.
1349 this->writeInstruction(SpvOpStore, pointer, value, out);
1350
1351 if (storageClass == StorageClass::kFunction) {
1352 // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1353 // return the cached value as-is.
1354 fStoreCache.set(pointer, value);
1355 fStoreOps.push_back(pointer);
1356 }
1357 }
1358
writeOpConstantTrue(const Type & type)1359 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1360 return this->writeInstruction(SpvOpConstantTrue,
1361 Words{this->getType(type), Word::Result()},
1362 fConstantBuffer);
1363 }
1364
writeOpConstantFalse(const Type & type)1365 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1366 return this->writeInstruction(SpvOpConstantFalse,
1367 Words{this->getType(type), Word::Result()},
1368 fConstantBuffer);
1369 }
1370
writeOpConstant(const Type & type,int32_t valueBits)1371 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1372 return this->writeInstruction(
1373 SpvOpConstant,
1374 Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1375 fConstantBuffer);
1376 }
1377
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1378 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1379 const TArray<SpvId>& values) {
1380 SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1381
1382 Words words;
1383 words.push_back(this->getType(type));
1384 words.push_back(Word::Result());
1385 for (SpvId value : values) {
1386 words.push_back(value);
1387 }
1388 return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1389 }
1390
toConstants(SpvId value,TArray<SpvId> * constants)1391 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1392 Instruction* instr = fSpvIdCache.find(value);
1393 if (!instr) {
1394 return false;
1395 }
1396 switch (instr->fOp) {
1397 case SpvOpConstant:
1398 case SpvOpConstantTrue:
1399 case SpvOpConstantFalse:
1400 constants->push_back(value);
1401 return true;
1402
1403 case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1404 // Start at word 2 to skip past ResultType and ResultID.
1405 for (int i = 2; i < instr->fWords.size(); ++i) {
1406 if (!this->toConstants(instr->fWords[i], constants)) {
1407 return false;
1408 }
1409 }
1410 return true;
1411
1412 default:
1413 return false;
1414 }
1415 }
1416
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1417 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1418 for (SpvId value : values) {
1419 if (!this->toConstants(value, constants)) {
1420 return false;
1421 }
1422 }
1423 return true;
1424 }
1425
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1426 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1427 const TArray<SpvId>& values,
1428 OutputStream& out) {
1429 // If this is a vector composed entirely of literals, write a constant-composite instead.
1430 if (type.isVector()) {
1431 STArray<4, SpvId> constants;
1432 if (this->toConstants(SkSpan(values), &constants)) {
1433 // Create a vector from literals.
1434 return this->writeOpConstantComposite(type, constants);
1435 }
1436 }
1437
1438 // If this is a matrix composed entirely of literals, constant-composite them instead.
1439 if (type.isMatrix()) {
1440 STArray<16, SpvId> constants;
1441 if (this->toConstants(SkSpan(values), &constants)) {
1442 // Create each matrix column.
1443 SkASSERT(type.isMatrix());
1444 const Type& vecType = type.columnType(fContext);
1445 STArray<4, SpvId> columnIDs;
1446 for (int index=0; index < type.columns(); ++index) {
1447 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1448 type.rows());
1449 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1450 }
1451 // Compose the matrix from its columns.
1452 return this->writeOpConstantComposite(type, columnIDs);
1453 }
1454 }
1455
1456 Words words;
1457 words.push_back(this->getType(type));
1458 words.push_back(Word::Result(type));
1459 for (SpvId value : values) {
1460 words.push_back(value);
1461 }
1462
1463 return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1464 }
1465
resultTypeForInstruction(const Instruction & instr)1466 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1467 const Instruction& instr) {
1468 // This list should contain every op that we cache that has a result and result-type.
1469 // (If one is missing, we will not find some optimization opportunities.)
1470 // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1471 // universally true, so it's configurable on a per-op basis.
1472 int resultTypeWord;
1473 switch (instr.fOp) {
1474 case SpvOpConstant:
1475 case SpvOpConstantTrue:
1476 case SpvOpConstantFalse:
1477 case SpvOpConstantComposite:
1478 case SpvOpCompositeConstruct:
1479 case SpvOpCompositeExtract:
1480 case SpvOpLoad:
1481 resultTypeWord = 0;
1482 break;
1483
1484 default:
1485 return nullptr;
1486 }
1487
1488 Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1489 SkASSERT(typeInstr);
1490 return typeInstr;
1491 }
1492
numComponentsForVecInstruction(const Instruction & instr)1493 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1494 // If an instruction is in the op cache, its type should be as well.
1495 Instruction* typeInstr = this->resultTypeForInstruction(instr);
1496 SkASSERT(typeInstr);
1497 SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1498 typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1499
1500 // For vectors, extract their column count. Scalars have one component by definition.
1501 // SpvOpTypeVector ResultID ComponentType NumComponents
1502 return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1503 : 1;
1504 }
1505
toComponent(SpvId id,int component)1506 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1507 Instruction* instr = fSpvIdCache.find(id);
1508 if (!instr) {
1509 return NA;
1510 }
1511 if (instr->fOp == SpvOpConstantComposite) {
1512 // SpvOpConstantComposite ResultType ResultID [components...]
1513 // Add 2 to the component index to skip past ResultType and ResultID.
1514 return instr->fWords[2 + component];
1515 }
1516 if (instr->fOp == SpvOpCompositeConstruct) {
1517 // SpvOpCompositeConstruct ResultType ResultID [components...]
1518 // Vectors have special rules; check to see if we are composing a vector.
1519 Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1520 SkASSERT(composedType);
1521
1522 // When composing a non-vector, each instruction word maps 1:1 to the component index.
1523 // We can just extract out the associated component directly.
1524 if (composedType->fOp != SpvOpTypeVector) {
1525 return instr->fWords[2 + component];
1526 }
1527
1528 // When composing a vector, components can be either scalars or vectors.
1529 // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1530 for (int index = 2; index < instr->fWords.size(); ++index) {
1531 int32_t currentWord = instr->fWords[index];
1532
1533 // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1534 Instruction* subinstr = fSpvIdCache.find(currentWord);
1535 if (!subinstr) {
1536 return NA;
1537 }
1538 // If this subinstruction contains the component we're looking for...
1539 int numComponents = this->numComponentsForVecInstruction(*subinstr);
1540 if (component < numComponents) {
1541 if (numComponents == 1) {
1542 // ... it's a scalar. Return it.
1543 SkASSERT(component == 0);
1544 return currentWord;
1545 } else {
1546 // ... it's a vector. Recurse into it.
1547 return this->toComponent(currentWord, component);
1548 }
1549 }
1550 // This sub-instruction doesn't contain our component. Keep walking forward.
1551 component -= numComponents;
1552 }
1553 SkDEBUGFAIL("component index goes past the end of this composite value");
1554 return NA;
1555 }
1556 return NA;
1557 }
1558
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1559 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1560 SpvId base,
1561 int component,
1562 OutputStream& out) {
1563 // If the base op is a composite, we can extract from it directly.
1564 SpvId result = this->toComponent(base, component);
1565 if (result != NA) {
1566 return result;
1567 }
1568 return this->writeInstruction(
1569 SpvOpCompositeExtract,
1570 {this->getType(type), Word::Result(type), base, Word::Number(component)},
1571 out);
1572 }
1573
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1574 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1575 SpvId base,
1576 int componentA,
1577 int componentB,
1578 OutputStream& out) {
1579 // If the base op is a composite, we can extract from it directly.
1580 SpvId result = this->toComponent(base, componentA);
1581 if (result != NA) {
1582 return this->writeOpCompositeExtract(type, result, componentB, out);
1583 }
1584 return this->writeInstruction(SpvOpCompositeExtract,
1585 {this->getType(type),
1586 Word::Result(type),
1587 base,
1588 Word::Number(componentA),
1589 Word::Number(componentB)},
1590 out);
1591 }
1592
writeCapabilities(OutputStream & out)1593 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1594 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1595 if (fCapabilities & bit) {
1596 this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1597 }
1598 }
1599 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1600 }
1601
nextId(const Type * type)1602 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1603 return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1604 ? Precision::kRelaxed
1605 : Precision::kDefault);
1606 }
1607
nextId(Precision precision)1608 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1609 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1610 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1611 fDecorationBuffer);
1612 }
1613 return fIdCount++;
1614 }
1615
// Emits an OpTypeStruct for `type` into fConstantBuffer and caches its ID in fStructMap. It also
// writes OpName/OpMemberName entries into fNameBuffer and per-member decorations into
// fDecorationBuffer. Those decorations are:
//   - Offset, computed from `memoryLayout` or taken from an explicit layout offset, and skipped
//     for built-ins;
//   - ColMajor and MatrixStride for matrix members;
//   - RelaxedPrecision for members whose type is not high-precision.
// Returns the struct's SpvId. If `type` was already written, the cached ID is returned and
// nothing new is emitted.
SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
    // If we've already written out this struct, return its existing SpvId.
    if (SpvId* cachedStructId = fStructMap.find(&type)) {
        return *cachedStructId;
    }

    // Write all of the field types first, so we don't inadvertently write them while we're in the
    // middle of writing the struct instruction.
    Words words;
    words.push_back(Word::UniqueResult());
    for (const auto& f : type.fields()) {
        words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
    }
    SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
    this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
    fStructMap.set(&type, resultId);

    // `offset` is the running byte offset at which the next member will be placed.
    size_t offset = 0;
    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
        const Field& field = type.fields()[i];
        if (!memoryLayout.isSupported(*field.fType)) {
            // Stop at the first unsupported member. The struct ID is still returned, but this
            // member and all later ones receive no names or decorations.
            fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
                                                    "' is not permitted here");
            return resultId;
        }
        size_t size = memoryLayout.size(*field.fType);
        size_t alignment = memoryLayout.alignment(*field.fType);
        const Layout& fieldLayout = field.fLayout;
        if (fieldLayout.fOffset >= 0) {
            // An explicit offset was provided. Validate it against the running offset and the
            // member's alignment, reporting errors without aborting, and then adopt it.
            if (fieldLayout.fOffset < (int) offset) {
                fContext.fErrors->error(field.fPosition, "offset of field '" +
                        std::string(field.fName) + "' must be at least " + std::to_string(offset));
            }
            if (fieldLayout.fOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
            }
            offset = fieldLayout.fOffset;
        } else {
            // No explicit offset: round the running offset up to this member's alignment.
            size_t mod = offset % alignment;
            if (mod) {
                offset += alignment - mod;
            }
        }
        this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
        this->writeFieldLayout(fieldLayout, resultId, i);
        // Built-in members (fBuiltin >= 0) are not given an Offset decoration.
        if (field.fLayout.fBuiltin < 0) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                   (SpvId) offset, fDecorationBuffer);
        }
        if (field.fType->isMatrix()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                   fDecorationBuffer);
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                   (SpvId) memoryLayout.stride(*field.fType),
                                   fDecorationBuffer);
        }
        if (!field.fType->highPrecision()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
                                   SpvDecorationRelaxedPrecision, fDecorationBuffer);
        }
        offset += size;
        // After an array or struct member, round the running offset up to that member's
        // alignment before placing the next member.
        if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
            offset += alignment - offset % alignment;
        }
    }

    return resultId;
}
1686
getType(const Type & type)1687 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1688 return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1689 }
1690
layout_flags_to_image_format(LayoutFlags flags)1691 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1692 flags &= LayoutFlag::kAllPixelFormats;
1693 switch (flags.value()) {
1694 case (int)LayoutFlag::kRGBA8:
1695 return SpvImageFormatRgba8;
1696
1697 case (int)LayoutFlag::kRGBA32F:
1698 return SpvImageFormatRgba32f;
1699
1700 case (int)LayoutFlag::kR32F:
1701 return SpvImageFormatR32f;
1702
1703 default:
1704 return SpvImageFormatUnknown;
1705 }
1706
1707 SkUNREACHABLE;
1708 }
1709
// Returns the SpvId of the SPIR-V type corresponding to `rawType`. The type declaration, and any
// component types it depends on, is emitted into fConstantBuffer.
//   - `typeLayout` supplies layout flags; here they are used to pick storage-image formats.
//   - `memoryLayout` drives array strides and (through writeStruct) struct member offsets.
// Array types the memory layout does not support produce an error and return NA. Unrecognized
// types hit a debug failure and return NA.
// NOTE(review): Word::Result()/Word::KeyedResult() presumably let writeInstruction reuse the ID
// of an identical, previously-written type instruction (in contrast to writeStruct's
// Word::UniqueResult()). Confirm against writeInstruction.
SpvId SPIRVCodeGenerator::getType(const Type& rawType,
                                  const Layout& typeLayout,
                                  const MemoryLayout& memoryLayout) {
    const Type* type = &rawType;

    switch (type->typeKind()) {
        case Type::TypeKind::kVoid: {
            return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kScalar:
        case Type::TypeKind::kLiteral: {
            // Every numeric scalar is declared with a 32-bit width. Reduced precision is not
            // expressed here; this file conveys it through RelaxedPrecision decorations
            // (see nextId/writeStruct).
            if (type->isBoolean()) {
                return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
            }
            if (type->isSigned()) {
                // OpTypeInt operands: width 32, signedness 1.
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(1)},
                        fConstantBuffer);
            }
            if (type->isUnsigned()) {
                // OpTypeInt operands: width 32, signedness 0.
                return this->writeInstruction(
                        SpvOpTypeInt,
                        Words{Word::Result(), Word::Number(32), Word::Number(0)},
                        fConstantBuffer);
            }
            if (type->isFloat()) {
                return this->writeInstruction(
                        SpvOpTypeFloat,
                        Words{Word::Result(), Word::Number(32)},
                        fConstantBuffer);
            }
            SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
            return NA;
        }
        case Type::TypeKind::kVector: {
            // A vector of `columns()` elements of its scalar component type.
            SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeVector,
                    Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kMatrix: {
            // SPIR-V matrices are built from column vectors. IndexType() gives the type obtained
            // by indexing the matrix, which is one column.
            SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
                                               typeLayout,
                                               memoryLayout);
            return this->writeInstruction(
                    SpvOpTypeMatrix,
                    Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
                    fConstantBuffer);
        }
        case Type::TypeKind::kArray: {
            // When fForceStd430ArrayLayout is set, arrays and their element types use std430
            // rules regardless of the memory layout that was passed in.
            const MemoryLayout arrayMemoryLayout =
                    fCaps.fForceStd430ArrayLayout
                            ? MemoryLayout(MemoryLayout::Standard::k430)
                            : memoryLayout;

            if (!arrayMemoryLayout.isSupported(*type)) {
                fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
                                                         "' is not permitted here");
                return NA;
            }
            size_t stride = arrayMemoryLayout.stride(*type);
            SpvId typeId = this->getType(type->componentType(), typeLayout, arrayMemoryLayout);
            SpvId result = NA;
            // The stride is part of the result key because two arrays with the same element type
            // and length may need different ArrayStride decorations.
            // NOTE(review): inferred from Word::KeyedResult(stride); confirm.
            if (type->isUnsizedArray()) {
                result = this->writeInstruction(SpvOpTypeRuntimeArray,
                                                Words{Word::KeyedResult(stride), typeId},
                                                fConstantBuffer);
            } else {
                // Sized arrays take their length as an `int` constant ID.
                SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
                result = this->writeInstruction(SpvOpTypeArray,
                                                Words{Word::KeyedResult(stride), typeId, countId},
                                                fConstantBuffer);
            }
            this->writeInstruction(SpvOpDecorate,
                                   {result, SpvDecorationArrayStride, Word::Number(stride)},
                                   fDecorationBuffer);
            return result;
        }
        case Type::TypeKind::kStruct: {
            // writeStruct() caches the struct's ID in fStructMap.
            return this->writeStruct(*type, memoryLayout);
        }
        case Type::TypeKind::kSeparateSampler: {
            return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
        }
        case Type::TypeKind::kSampler: {
            // A combined sampler is an OpTypeSampledImage wrapping its texture type.
            // Buffer-dimension samplers additionally require the SampledBuffer capability.
            if (SpvDimBuffer == type->dimensions()) {
                fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
            }
            SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
            return this->writeInstruction(SpvOpTypeSampledImage,
                                          Words{Word::Result(), imageTypeId},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kTexture: {
            // The image's sampled type is always `float`.
            SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
                                              kDefaultTypeLayout,
                                              memoryLayout);

            // "Sampled" operand: 1 = used with a sampler, 2 = used without one (storage image).
            // Only storage images that are not subpass inputs take a format from the layout flags.
            bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
            SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
                                            ? layout_flags_to_image_format(typeLayout.fFlags)
                                            : SpvImageFormatUnknown;

            return this->writeInstruction(SpvOpTypeImage,
                                          Words{Word::Result(),
                                                floatTypeId,
                                                Word::Number(type->dimensions()),
                                                Word::Number(type->isDepth()),
                                                Word::Number(type->isArrayedTexture()),
                                                Word::Number(type->isMultisampled()),
                                                Word::Number(sampled ? 1 : 2),
                                                format},
                                          fConstantBuffer);
        }
        case Type::TypeKind::kAtomic: {
            // SkSL currently only supports the atomicUint type.
            SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
            // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
            // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
            return this->writeInstruction(SpvOpTypeInt,
                                          Words{Word::Result(), Word::Number(32), Word::Number(0)},
                                          fConstantBuffer);
        }
        default: {
            SkDEBUGFAILF("invalid type: %s", type->description().c_str());
            return NA;
        }
    }
}
1841
getFunctionType(const FunctionDeclaration & function)1842 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1843 Words words;
1844 words.push_back(Word::Result());
1845 words.push_back(this->getType(function.returnType()));
1846 for (const Variable* parameter : function.parameters()) {
1847 bool paramIsSpecialized = fActiveSpecialization && fActiveSpecialization->find(parameter);
1848 if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1849 words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1850 parameter->layout()));
1851 if (!paramIsSpecialized) {
1852 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1853 kDefaultTypeLayout));
1854 }
1855 } else if (!paramIsSpecialized) {
1856 words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1857 }
1858 }
1859 return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1860 }
1861
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1862 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1863 const Layout& parameterLayout) {
1864 // glslang treats all function arguments as pointers whether they need to be or
1865 // not. I was initially puzzled by this until I ran bizarre failures with certain
1866 // patterns of function calls and control constructs, as exemplified by this minimal
1867 // failure case:
1868 //
1869 // void sphere(float x) {
1870 // }
1871 //
1872 // void map() {
1873 // sphere(1.0);
1874 // }
1875 //
1876 // void main() {
1877 // for (int i = 0; i < 1; i++) {
1878 // map();
1879 // }
1880 // }
1881 //
1882 // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1883 // crashes. Making it take a float* and storing the argument in a temporary variable,
1884 // as glslang does, fixes it.
1885 //
1886 // The consensus among shader compiler authors seems to be that GPU driver generally don't
1887 // handle value-based parameters consistently. It is highly likely that they fit their
1888 // implementations to conform to glslang. We take care to do so ourselves.
1889 //
1890 // Our implementation first stores every parameter value into a function storage-class pointer
1891 // before calling a function. The exception is for opaque handle types (samplers and textures)
1892 // which must be stored in a pointer with UniformConstant storage-class. This prevents
1893 // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1894 // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1895 StorageClass storageClass;
1896 if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1897 parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1898 parameterType.typeKind() == Type::TypeKind::kTexture) {
1899 storageClass = StorageClass::kUniformConstant;
1900 } else {
1901 storageClass = StorageClass::kFunction;
1902 }
1903 return this->getPointerType(parameterType,
1904 parameterLayout,
1905 this->memoryLayoutForStorageClass(storageClass),
1906 storageClass);
1907 }
1908
getPointerType(const Type & type,StorageClass storageClass)1909 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, StorageClass storageClass) {
1910 return this->getPointerType(type,
1911 kDefaultTypeLayout,
1912 this->memoryLayoutForStorageClass(storageClass),
1913 storageClass);
1914 }
1915
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,StorageClass storageClass)1916 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1917 const Layout& typeLayout,
1918 const MemoryLayout& memoryLayout,
1919 StorageClass storageClass) {
1920 return this->writeInstruction(SpvOpTypePointer,
1921 Words{Word::Result(),
1922 Word::Number(get_storage_class_spv_id(storageClass)),
1923 this->getType(type, typeLayout, memoryLayout)},
1924 fConstantBuffer);
1925 }
1926
writeExpression(const Expression & expr,OutputStream & out)1927 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1928 switch (expr.kind()) {
1929 case Expression::Kind::kBinary:
1930 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1931 case Expression::Kind::kConstructorArrayCast:
1932 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1933 case Expression::Kind::kConstructorArray:
1934 case Expression::Kind::kConstructorStruct:
1935 return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1936 case Expression::Kind::kConstructorDiagonalMatrix:
1937 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1938 case Expression::Kind::kConstructorMatrixResize:
1939 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1940 case Expression::Kind::kConstructorScalarCast:
1941 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1942 case Expression::Kind::kConstructorSplat:
1943 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1944 case Expression::Kind::kConstructorCompound:
1945 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1946 case Expression::Kind::kConstructorCompoundCast:
1947 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1948 case Expression::Kind::kEmpty:
1949 return NA;
1950 case Expression::Kind::kFieldAccess:
1951 return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1952 case Expression::Kind::kFunctionCall:
1953 return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1954 case Expression::Kind::kLiteral:
1955 return this->writeLiteral(expr.as<Literal>());
1956 case Expression::Kind::kPrefix:
1957 return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1958 case Expression::Kind::kPostfix:
1959 return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1960 case Expression::Kind::kSwizzle:
1961 return this->writeSwizzle(expr.as<Swizzle>(), out);
1962 case Expression::Kind::kVariableReference:
1963 return this->writeVariableReference(expr.as<VariableReference>(), out);
1964 case Expression::Kind::kTernary:
1965 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1966 case Expression::Kind::kIndex:
1967 return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1968 case Expression::Kind::kSetting:
1969 return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
1970 default:
1971 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1972 break;
1973 }
1974 return NA;
1975 }
1976
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1977 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1978 const FunctionDeclaration& function = c.function();
1979 Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1980 if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1981 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1982 "'");
1983 return NA;
1984 }
1985 const ExpressionArray& arguments = c.arguments();
1986 int32_t intrinsicId = intrinsic.floatOp;
1987 if (!arguments.empty()) {
1988 const Type& type = arguments[0]->type();
1989 if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1990 // Keep the default float op.
1991 } else {
1992 intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1993 intrinsic.unsignedOp, intrinsic.boolOp);
1994 }
1995 }
1996 switch (intrinsic.opKind) {
1997 case kGLSL_STD_450_IntrinsicOpcodeKind: {
1998 SpvId result = this->nextId(&c.type());
1999 TArray<SpvId> argumentIds;
2000 argumentIds.reserve_exact(arguments.size());
2001 std::vector<TempVar> tempVars;
2002 for (int i = 0; i < arguments.size(); i++) {
2003 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2004 /*specializedParams=*/nullptr, out);
2005 }
2006 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2007 this->writeWord(this->getType(c.type()), out);
2008 this->writeWord(result, out);
2009 this->writeWord(fGLSLExtendedInstructions, out);
2010 this->writeWord(intrinsicId, out);
2011 for (SpvId id : argumentIds) {
2012 this->writeWord(id, out);
2013 }
2014 this->copyBackTempVars(tempVars, out);
2015 return result;
2016 }
2017 case kSPIRV_IntrinsicOpcodeKind: {
2018 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
2019 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
2020 intrinsicId = SpvOpFMul;
2021 }
2022 SpvId result = this->nextId(&c.type());
2023 TArray<SpvId> argumentIds;
2024 argumentIds.reserve_exact(arguments.size());
2025 std::vector<TempVar> tempVars;
2026 for (int i = 0; i < arguments.size(); i++) {
2027 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2028 /*specializedParams=*/nullptr, out);
2029 }
2030 if (!c.type().isVoid()) {
2031 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
2032 this->writeWord(this->getType(c.type()), out);
2033 this->writeWord(result, out);
2034 } else {
2035 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
2036 }
2037 for (SpvId id : argumentIds) {
2038 this->writeWord(id, out);
2039 }
2040 this->copyBackTempVars(tempVars, out);
2041 return result;
2042 }
2043 case kSpecial_IntrinsicOpcodeKind:
2044 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
2045 default:
2046 fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
2047 function.description() + "'");
2048 return NA;
2049 }
2050 }
2051
vectorize(const Expression & arg,int vectorSize,OutputStream & out)2052 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
2053 SkASSERT(vectorSize >= 1 && vectorSize <= 4);
2054 const Type& argType = arg.type();
2055 if (argType.isScalar() && vectorSize > 1) {
2056 SpvId argID = this->writeExpression(arg, out);
2057 return this->splat(argType.toCompound(fContext, vectorSize, /*rows=*/1), argID, out);
2058 }
2059
2060 SkASSERT(vectorSize == argType.columns());
2061 return this->writeExpression(arg, out);
2062 }
2063
vectorize(const ExpressionArray & args,OutputStream & out)2064 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2065 int vectorSize = 1;
2066 for (const auto& a : args) {
2067 if (a->type().isVector()) {
2068 if (vectorSize > 1) {
2069 SkASSERT(a->type().columns() == vectorSize);
2070 } else {
2071 vectorSize = a->type().columns();
2072 }
2073 }
2074 }
2075 TArray<SpvId> result;
2076 result.reserve_exact(args.size());
2077 for (const auto& arg : args) {
2078 result.push_back(this->vectorize(*arg, vectorSize, out));
2079 }
2080 return result;
2081 }
2082
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2083 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2084 SpvId signedInst, SpvId unsignedInst,
2085 const TArray<SpvId>& args,
2086 OutputStream& out) {
2087 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2088 this->writeWord(this->getType(type), out);
2089 this->writeWord(id, out);
2090 this->writeWord(fGLSLExtendedInstructions, out);
2091 this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2092 for (SpvId a : args) {
2093 this->writeWord(a, out);
2094 }
2095 }
2096
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2097 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2098 OutputStream& out) {
2099 const ExpressionArray& arguments = c.arguments();
2100 const Type& callType = c.type();
2101 SpvId result = this->nextId(nullptr);
2102 switch (kind) {
2103 case kAtan_SpecialIntrinsic: {
2104 STArray<2, SpvId> argumentIds;
2105 for (const std::unique_ptr<Expression>& arg : arguments) {
2106 argumentIds.push_back(this->writeExpression(*arg, out));
2107 }
2108 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2109 this->writeWord(this->getType(callType), out);
2110 this->writeWord(result, out);
2111 this->writeWord(fGLSLExtendedInstructions, out);
2112 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2113 for (SpvId id : argumentIds) {
2114 this->writeWord(id, out);
2115 }
2116 break;
2117 }
2118 case kSampledImage_SpecialIntrinsic: {
2119 SkASSERT(arguments.size() == 2);
2120 SpvId img = this->writeExpression(*arguments[0], out);
2121 SpvId sampler = this->writeExpression(*arguments[1], out);
2122 this->writeInstruction(SpvOpSampledImage,
2123 this->getType(callType),
2124 result,
2125 img,
2126 sampler,
2127 out);
2128 break;
2129 }
2130 case kSubpassLoad_SpecialIntrinsic: {
2131 SpvId img = this->writeExpression(*arguments[0], out);
2132 ExpressionArray args;
2133 args.reserve_exact(2);
2134 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2135 args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2136 ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2137 SpvId coords = this->writeExpression(ctor, out);
2138 if (arguments.size() == 1) {
2139 this->writeInstruction(SpvOpImageRead,
2140 this->getType(callType),
2141 result,
2142 img,
2143 coords,
2144 out);
2145 } else {
2146 SkASSERT(arguments.size() == 2);
2147 SpvId sample = this->writeExpression(*arguments[1], out);
2148 this->writeInstruction(SpvOpImageRead,
2149 this->getType(callType),
2150 result,
2151 img,
2152 coords,
2153 SpvImageOperandsSampleMask,
2154 sample,
2155 out);
2156 }
2157 break;
2158 }
2159 case kTexture_SpecialIntrinsic: {
2160 SpvOp_ op = SpvOpImageSampleImplicitLod;
2161 const Type& arg1Type = arguments[1]->type();
2162 switch (arguments[0]->type().dimensions()) {
2163 case SpvDim1D:
2164 if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2165 op = SpvOpImageSampleProjImplicitLod;
2166 } else {
2167 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2168 }
2169 break;
2170 case SpvDim2D:
2171 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2172 op = SpvOpImageSampleProjImplicitLod;
2173 } else {
2174 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2175 }
2176 break;
2177 case SpvDim3D:
2178 if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2179 op = SpvOpImageSampleProjImplicitLod;
2180 } else {
2181 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2182 }
2183 break;
2184 case SpvDimCube: // fall through
2185 case SpvDimRect: // fall through
2186 case SpvDimBuffer: // fall through
2187 case SpvDimSubpassData:
2188 break;
2189 }
2190 SpvId type = this->getType(callType);
2191 SpvId sampler = this->writeExpression(*arguments[0], out);
2192 SpvId uv = this->writeExpression(*arguments[1], out);
2193 if (arguments.size() == 3) {
2194 this->writeInstruction(op, type, result, sampler, uv,
2195 SpvImageOperandsBiasMask,
2196 this->writeExpression(*arguments[2], out),
2197 out);
2198 } else {
2199 SkASSERT(arguments.size() == 2);
2200 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2201 SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2202 *fContext.fTypes.fFloat);
2203 this->writeInstruction(op, type, result, sampler, uv,
2204 SpvImageOperandsBiasMask, lodBias, out);
2205 } else {
2206 this->writeInstruction(op, type, result, sampler, uv,
2207 out);
2208 }
2209 }
2210 break;
2211 }
2212 case kTextureGrad_SpecialIntrinsic: {
2213 SpvOp_ op = SpvOpImageSampleExplicitLod;
2214 SkASSERT(arguments.size() == 4);
2215 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2216 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2217 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2218 SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2219 SpvId type = this->getType(callType);
2220 SpvId sampler = this->writeExpression(*arguments[0], out);
2221 SpvId uv = this->writeExpression(*arguments[1], out);
2222 SpvId dPdx = this->writeExpression(*arguments[2], out);
2223 SpvId dPdy = this->writeExpression(*arguments[3], out);
2224 this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2225 dPdx, dPdy, out);
2226 break;
2227 }
2228 case kTextureLod_SpecialIntrinsic: {
2229 SpvOp_ op = SpvOpImageSampleExplicitLod;
2230 SkASSERT(arguments.size() == 3);
2231 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2232 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2233 const Type& arg1Type = arguments[1]->type();
2234 if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2235 op = SpvOpImageSampleProjExplicitLod;
2236 } else {
2237 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2238 }
2239 SpvId type = this->getType(callType);
2240 SpvId sampler = this->writeExpression(*arguments[0], out);
2241 SpvId uv = this->writeExpression(*arguments[1], out);
2242 this->writeInstruction(op, type, result, sampler, uv,
2243 SpvImageOperandsLodMask,
2244 this->writeExpression(*arguments[2], out),
2245 out);
2246 break;
2247 }
2248 case kTextureRead_SpecialIntrinsic: {
2249 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2250 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2251
2252 SpvId type = this->getType(callType);
2253 SpvId image = this->writeExpression(*arguments[0], out);
2254 SpvId coord = this->writeExpression(*arguments[1], out);
2255
2256 const Type& arg0Type = arguments[0]->type();
2257 SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2258
2259 switch (arg0Type.textureAccess()) {
2260 case Type::TextureAccess::kSample:
2261 this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2262 SpvImageOperandsLodMask,
2263 this->writeOpConstant(*fContext.fTypes.fInt, 0),
2264 out);
2265 break;
2266 case Type::TextureAccess::kRead:
2267 case Type::TextureAccess::kReadWrite:
2268 this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2269 break;
2270 case Type::TextureAccess::kWrite:
2271 default:
2272 SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2273 break;
2274 }
2275
2276 break;
2277 }
2278 case kTextureWrite_SpecialIntrinsic: {
2279 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2280 SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2281 SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2282
2283 SpvId image = this->writeExpression(*arguments[0], out);
2284 SpvId coord = this->writeExpression(*arguments[1], out);
2285 SpvId texel = this->writeExpression(*arguments[2], out);
2286
2287 this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2288 break;
2289 }
2290 case kTextureWidth_SpecialIntrinsic:
2291 case kTextureHeight_SpecialIntrinsic: {
2292 SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2293 fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2294
2295 SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2296 SpvId dims = this->nextId(nullptr);
2297 SpvId image = this->writeExpression(*arguments[0], out);
2298 this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2299
2300 SpvId type = this->getType(callType);
2301 int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2302 this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2303 break;
2304 }
2305 case kMod_SpecialIntrinsic: {
2306 TArray<SpvId> args = this->vectorize(arguments, out);
2307 SkASSERT(args.size() == 2);
2308 const Type& operandType = arguments[0]->type();
2309 SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2310 SkASSERT(op != SpvOpUndef);
2311 this->writeOpCode(op, 5, out);
2312 this->writeWord(this->getType(operandType), out);
2313 this->writeWord(result, out);
2314 this->writeWord(args[0], out);
2315 this->writeWord(args[1], out);
2316 break;
2317 }
2318 case kDFdy_SpecialIntrinsic: {
2319 SpvId fn = this->writeExpression(*arguments[0], out);
2320 this->writeOpCode(SpvOpDPdy, 4, out);
2321 this->writeWord(this->getType(callType), out);
2322 this->writeWord(result, out);
2323 this->writeWord(fn, out);
2324 if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2325 this->addRTFlipUniform(c.fPosition);
2326 ComponentArray componentArray;
2327 for (int index = 0; index < callType.columns(); ++index) {
2328 componentArray.push_back(SwizzleComponent::Y);
2329 }
2330 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2331 componentArray, out);
2332 SpvId flipped = this->nextId(&callType);
2333 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2334 rtFlipY, out);
2335 result = flipped;
2336 }
2337 break;
2338 }
2339 case kClamp_SpecialIntrinsic: {
2340 TArray<SpvId> args = this->vectorize(arguments, out);
2341 SkASSERT(args.size() == 3);
2342 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2343 GLSLstd450UClamp, args, out);
2344 break;
2345 }
2346 case kMax_SpecialIntrinsic: {
2347 TArray<SpvId> args = this->vectorize(arguments, out);
2348 SkASSERT(args.size() == 2);
2349 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2350 GLSLstd450UMax, args, out);
2351 break;
2352 }
2353 case kMin_SpecialIntrinsic: {
2354 TArray<SpvId> args = this->vectorize(arguments, out);
2355 SkASSERT(args.size() == 2);
2356 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2357 GLSLstd450UMin, args, out);
2358 break;
2359 }
2360 case kMix_SpecialIntrinsic: {
2361 TArray<SpvId> args = this->vectorize(arguments, out);
2362 SkASSERT(args.size() == 3);
2363 if (arguments[2]->type().componentType().isBoolean()) {
2364 // Use OpSelect to implement Boolean mix().
2365 SpvId falseId = this->writeExpression(*arguments[0], out);
2366 SpvId trueId = this->writeExpression(*arguments[1], out);
2367 SpvId conditionId = this->writeExpression(*arguments[2], out);
2368 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2369 conditionId, trueId, falseId, out);
2370 } else {
2371 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2372 SpvOpUndef, args, out);
2373 }
2374 break;
2375 }
2376 case kSaturate_SpecialIntrinsic: {
2377 SkASSERT(arguments.size() == 1);
2378 int width = arguments[0]->type().columns();
2379 STArray<3, SpvId> spvArgs{
2380 this->vectorize(*arguments[0], width, out),
2381 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/0), width, out),
2382 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/1), width, out),
2383 };
2384 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2385 GLSLstd450UClamp, spvArgs, out);
2386 break;
2387 }
2388 case kSmoothStep_SpecialIntrinsic: {
2389 TArray<SpvId> args = this->vectorize(arguments, out);
2390 SkASSERT(args.size() == 3);
2391 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2392 SpvOpUndef, args, out);
2393 break;
2394 }
2395 case kStep_SpecialIntrinsic: {
2396 TArray<SpvId> args = this->vectorize(arguments, out);
2397 SkASSERT(args.size() == 2);
2398 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2399 SpvOpUndef, args, out);
2400 break;
2401 }
2402 case kMatrixCompMult_SpecialIntrinsic: {
2403 SkASSERT(arguments.size() == 2);
2404 SpvId lhs = this->writeExpression(*arguments[0], out);
2405 SpvId rhs = this->writeExpression(*arguments[1], out);
2406 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2407 break;
2408 }
2409 case kAtomicAdd_SpecialIntrinsic:
2410 case kAtomicLoad_SpecialIntrinsic:
2411 case kAtomicStore_SpecialIntrinsic:
2412 result = this->writeAtomicIntrinsic(c, kind, result, out);
2413 break;
2414 case kStorageBarrier_SpecialIntrinsic:
2415 case kWorkgroupBarrier_SpecialIntrinsic: {
2416 // Both barrier types operate in the workgroup execution and memory scope and differ
2417 // only in memory semantics. storageBarrier() is not a device-scope barrier.
2418 SpvId scopeId =
2419 this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2420 int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2421 ? SpvMemorySemanticsAcquireReleaseMask |
2422 SpvMemorySemanticsUniformMemoryMask
2423 : SpvMemorySemanticsAcquireReleaseMask |
2424 SpvMemorySemanticsWorkgroupMemoryMask;
2425 SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2426 this->writeInstruction(SpvOpControlBarrier,
2427 scopeId, // execution scope
2428 scopeId, // memory scope
2429 memorySemanticsId,
2430 out);
2431 break;
2432 }
2433 }
2434 return result;
2435 }
2436
writeAtomicIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SpvId resultId,OutputStream & out)2437 SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
2438 SpecialIntrinsic kind,
2439 SpvId resultId,
2440 OutputStream& out) {
2441 const ExpressionArray& arguments = c.arguments();
2442 SkASSERT(!arguments.empty());
2443
2444 std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
2445 SpvId atomicPtrId = atomicPtr->getPointer();
2446 if (atomicPtrId == NA) {
2447 SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
2448 arguments[0]->description().c_str());
2449 return NA;
2450 }
2451
2452 SpvId memoryScopeId = NA;
2453 {
2454 // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
2455 // member. The two memory scopes that these map to are "workgroup" and "device",
2456 // respectively.
2457 SpvScope memoryScope;
2458 switch (atomicPtr->storageClass()) {
2459 case StorageClass::kUniform:
2460 case StorageClass::kStorageBuffer:
2461 // We encode storage buffers in the uniform address space (with the BufferBlock
2462 // decorator).
2463 memoryScope = SpvScopeDevice;
2464 break;
2465 case StorageClass::kWorkgroup:
2466 memoryScope = SpvScopeWorkgroup;
2467 break;
2468 default:
2469 SkDEBUGFAILF("atomic argument has invalid storage class: %d",
2470 get_storage_class_spv_id(atomicPtr->storageClass()));
2471 return NA;
2472 }
2473 memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
2474 }
2475
2476 SpvId relaxedMemoryOrderId =
2477 this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);
2478
2479 switch (kind) {
2480 case kAtomicAdd_SpecialIntrinsic:
2481 SkASSERT(arguments.size() == 2);
2482 this->writeInstruction(SpvOpAtomicIAdd,
2483 this->getType(c.type()),
2484 resultId,
2485 atomicPtrId,
2486 memoryScopeId,
2487 relaxedMemoryOrderId,
2488 this->writeExpression(*arguments[1], out),
2489 out);
2490 break;
2491 case kAtomicLoad_SpecialIntrinsic:
2492 SkASSERT(arguments.size() == 1);
2493 this->writeInstruction(SpvOpAtomicLoad,
2494 this->getType(c.type()),
2495 resultId,
2496 atomicPtrId,
2497 memoryScopeId,
2498 relaxedMemoryOrderId,
2499 out);
2500 break;
2501 case kAtomicStore_SpecialIntrinsic:
2502 SkASSERT(arguments.size() == 2);
2503 this->writeInstruction(SpvOpAtomicStore,
2504 atomicPtrId,
2505 memoryScopeId,
2506 relaxedMemoryOrderId,
2507 this->writeExpression(*arguments[1], out),
2508 out);
2509 break;
2510 default:
2511 SkUNREACHABLE;
2512 }
2513
2514 return resultId;
2515 }
2516
// Appends the SpvId(s) to pass for argument `argIndex` of `call` onto `argumentList`.
//
// - Specialized parameters (per `specializedParams`, which may be null) are omitted, except a
//   sampler parameter in texture/sampler-pair mode, whose texture is still forwarded.
// - Opaque VariableReference arguments (sampler, separate sampler, texture) pass the variable's
//   own id. In texture/sampler-pair mode a sampler expands to its texture id and, unless
//   specialized, its sampler id.
// - `out` parameters are passed through a fresh Function-storage OpVariable (declared into
//   fVariableBuffer). The original lvalue is recorded in `tempVars` so the caller can copy the
//   value back after the call. `in out` parameters are pre-loaded into that temporary.
// - Remaining intrinsic arguments are passed by value (writeExpression).
// - Remaining user-function arguments are stored into a fresh Function-storage OpVariable and
//   passed by pointer.
void SPIRVCodeGenerator::writeFunctionCallArgument(TArray<SpvId>& argumentList,
                                                   const FunctionCall& call,
                                                   int argIndex,
                                                   std::vector<TempVar>* tempVars,
                                                   const SkBitSet* specializedParams,
                                                   OutputStream& out) {
    const FunctionDeclaration& funcDecl = call.function();
    const Expression& arg = *call.arguments()[argIndex];
    const Variable* param = funcDecl.parameters()[argIndex];
    bool paramIsSpecialized = specializedParams && specializedParams->test(argIndex);
    ModifierFlags paramFlags = param->modifierFlags();

    // Specialized parameters are omitted from the call. The exception is a sampler parameter when
    // fUseTextureSamplerPairs is true: its texture is still forwarded below, and only the sampler
    // half of the texture/sampler pair is dropped.
    if (paramIsSpecialized && !(param->type().isSampler() && fUseTextureSamplerPairs)) {
        return;
    }

    if (arg.is<VariableReference>() && (arg.type().typeKind() == Type::TypeKind::kSampler ||
                                        arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
                                        arg.type().typeKind() == Type::TypeKind::kTexture)) {
        // Opaque handle (sampler/texture) arguments are always declared as pointers but never
        // stored in intermediates when calling user-defined functions.
        //
        // NOTE(review): this branch is checked before the intrinsic by-value path further below,
        // so it also applies to opaque VariableReference arguments of intrinsics that reach this
        // function.
        //
        // See getFunctionParameterType for further explanation.
        const Variable* var = arg.as<VariableReference>().variable();

        // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
        if (fUseTextureSamplerPairs && var->type().isSampler()) {
            if (const auto* p = fSynthesizedSamplerMap.find(var)) {
                SpvId* img = fVariableMap.find((*p)->fTexture.get());
                SkASSERT(img);

                argumentList.push_back(*img);

                // A specialized sampler parameter contributes only its texture.
                if (!paramIsSpecialized) {
                    SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
                    SkASSERT(sampler);
                    argumentList.push_back(*sampler);
                }
                return;
            }
            // In release builds, fall through and pass the sampler variable itself.
            SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
        }

        SpvId* entry = fVariableMap.find(var);
        SkASSERTF(entry, "%s", arg.description().c_str());
        argumentList.push_back(*entry);
        return;
    }
    SkASSERT(!paramIsSpecialized);

    // ID of the Function-storage temporary that is passed in place of the argument. Every path
    // that reaches the OpVariable below assigns it.
    SpvId tmpVar = NA;
    // Value to store into the temporary before the call, or NA to leave it unwritten (pure `out`
    // parameters).
    SpvId tmpValueId = NA;

    if (is_out(paramFlags)) {
        std::unique_ptr<LValue> lv = this->getLValue(arg, out);
        // We handle out params with a temp var that we copy back to the original variable at the
        // end of the call. GLSL guarantees that the original variable will be unchanged until the
        // end of the call, and also that out params are written back to their original variables in
        // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
        if (is_in(paramFlags)) {
            tmpValueId = lv->load(out);
        }
        tmpVar = this->nextId(&arg.type());
        tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
    } else if (funcDecl.isIntrinsic()) {
        // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
        argumentList.push_back(this->writeExpression(arg, out));
        return;
    } else {
        // We always use pointer parameters when calling user functions.
        // See getFunctionParameterType for further explanation.
        tmpValueId = this->writeExpression(arg, out);
        tmpVar = this->nextId(nullptr);
    }
    // The temporary's declaration goes into fVariableBuffer; its initializing store goes to `out`.
    this->writeInstruction(SpvOpVariable,
                           this->getPointerType(arg.type(), StorageClass::kFunction),
                           tmpVar,
                           SpvStorageClassFunction,
                           fVariableBuffer);
    if (tmpValueId != NA) {
        this->writeOpStore(StorageClass::kFunction, tmpVar, tmpValueId, out);
    }
    argumentList.push_back(tmpVar);
}
2610
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2611 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2612 for (const TempVar& tempVar : tempVars) {
2613 SpvId load = this->nextId(tempVar.type);
2614 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2615 tempVar.lvalue->store(load, out);
2616 }
2617 }
2618
writeFunctionCall(const FunctionCall & c,OutputStream & out)2619 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2620 // Handle intrinsics.
2621 const FunctionDeclaration& function = c.function();
2622 if (function.isIntrinsic() && !function.definition()) {
2623 return this->writeIntrinsicCall(c, out);
2624 }
2625
2626 // Look up this function (or its specialization, if any) in our map of function SpvIds.
2627 Analysis::SpecializationIndex specializationIndex = Analysis::FindSpecializationIndexForCall(
2628 c, fSpecializationInfo, fActiveSpecializationIndex);
2629 SpvId* entry = fFunctionMap.find({&function, specializationIndex});
2630 if (!entry) {
2631 fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2632 "' is not defined");
2633 return NA;
2634 }
2635
2636 // If we are calling a specialized function, we need to gather the specialized parameters
2637 // so we can remove them from the argument list.
2638 SkBitSet specializedParams =
2639 Analysis::FindSpecializedParametersForFunction(c.function(), fSpecializationInfo);
2640
2641 // Temp variables are used to write back out-parameters after the function call is complete.
2642 const ExpressionArray& arguments = c.arguments();
2643 std::vector<TempVar> tempVars;
2644 TArray<SpvId> argumentIds;
2645 argumentIds.reserve_exact(arguments.size());
2646 for (int i = 0; i < arguments.size(); i++) {
2647 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars, &specializedParams, out);
2648 }
2649 SpvId result = this->nextId(nullptr);
2650 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2651 this->writeWord(this->getType(c.type()), out);
2652 this->writeWord(result, out);
2653 this->writeWord(*entry, out);
2654 for (SpvId id : argumentIds) {
2655 this->writeWord(id, out);
2656 }
2657 // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2658 this->copyBackTempVars(tempVars, out);
2659 return result;
2660 }
2661
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2662 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2663 const Type& inputType,
2664 const Type& outputType,
2665 OutputStream& out) {
2666 if (outputType.isFloat()) {
2667 return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2668 }
2669 if (outputType.isSigned()) {
2670 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2671 }
2672 if (outputType.isUnsigned()) {
2673 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2674 }
2675 if (outputType.isBoolean()) {
2676 return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2677 }
2678
2679 fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2680 outputType.description());
2681 return inputExprId;
2682 }
2683
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2684 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2685 const Type& outputType, OutputStream& out) {
2686 // Casting a float to float is a no-op.
2687 if (inputType.isFloat()) {
2688 return inputId;
2689 }
2690
2691 // Given the input type, generate the appropriate instruction to cast to float.
2692 SpvId result = this->nextId(&outputType);
2693 if (inputType.isBoolean()) {
2694 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2695 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2696 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2697 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2698 inputId, oneID, zeroID, out);
2699 } else if (inputType.isSigned()) {
2700 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2701 } else if (inputType.isUnsigned()) {
2702 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2703 } else {
2704 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2705 return NA;
2706 }
2707 return result;
2708 }
2709
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2710 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2711 const Type& outputType, OutputStream& out) {
2712 // Casting a signed int to signed int is a no-op.
2713 if (inputType.isSigned()) {
2714 return inputId;
2715 }
2716
2717 // Given the input type, generate the appropriate instruction to cast to signed int.
2718 SpvId result = this->nextId(&outputType);
2719 if (inputType.isBoolean()) {
2720 // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2721 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2722 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2723 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2724 inputId, oneID, zeroID, out);
2725 } else if (inputType.isFloat()) {
2726 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2727 } else if (inputType.isUnsigned()) {
2728 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2729 } else {
2730 SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2731 inputType.description().c_str());
2732 return NA;
2733 }
2734 return result;
2735 }
2736
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2737 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2738 const Type& outputType, OutputStream& out) {
2739 // Casting an unsigned int to unsigned int is a no-op.
2740 if (inputType.isUnsigned()) {
2741 return inputId;
2742 }
2743
2744 // Given the input type, generate the appropriate instruction to cast to unsigned int.
2745 SpvId result = this->nextId(&outputType);
2746 if (inputType.isBoolean()) {
2747 // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2748 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2749 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2750 this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2751 inputId, oneID, zeroID, out);
2752 } else if (inputType.isFloat()) {
2753 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2754 } else if (inputType.isSigned()) {
2755 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2756 } else {
2757 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2758 inputType.description().c_str());
2759 return NA;
2760 }
2761 return result;
2762 }
2763
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2764 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2765 const Type& outputType, OutputStream& out) {
2766 // Casting a bool to bool is a no-op.
2767 if (inputType.isBoolean()) {
2768 return inputId;
2769 }
2770
2771 // Given the input type, generate the appropriate instruction to cast to bool.
2772 SpvId result = this->nextId(nullptr);
2773 if (inputType.isSigned()) {
2774 // Synthesize a boolean result by comparing the input against a signed zero literal.
2775 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2776 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2777 inputId, zeroID, out);
2778 } else if (inputType.isUnsigned()) {
2779 // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2780 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2781 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2782 inputId, zeroID, out);
2783 } else if (inputType.isFloat()) {
2784 // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2785 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2786 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2787 inputId, zeroID, out);
2788 } else {
2789 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2790 return NA;
2791 }
2792 return result;
2793 }
2794
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)2795 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
2796 OutputStream& out) {
2797 SkASSERT(srcType.isMatrix());
2798 SkASSERT(dstType.isMatrix());
2799 SkASSERT(srcType.componentType().matches(dstType.componentType()));
2800 const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
2801 const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
2802 SkASSERT(dstType.componentType().isFloat());
2803 SpvId dstColumnTypeId = this->getType(dstColumnType);
2804 const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
2805 const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
2806
2807 STArray<4, SpvId> columns;
2808 for (int i = 0; i < dstType.columns(); i++) {
2809 if (i < srcType.columns()) {
2810 // we're still inside the src matrix, copy the column
2811 SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
2812 SpvId dstColumn;
2813 if (srcType.rows() == dstType.rows()) {
2814 // columns are equal size, don't need to do anything
2815 dstColumn = srcColumn;
2816 }
2817 else if (dstType.rows() > srcType.rows()) {
2818 // dst column is bigger, need to zero-pad it
2819 STArray<4, SpvId> values;
2820 values.push_back(srcColumn);
2821 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
2822 values.push_back((i == j) ? oneId : zeroId);
2823 }
2824 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
2825 }
2826 else {
2827 // dst column is smaller, need to swizzle the src column
2828 dstColumn = this->nextId(&dstType);
2829 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
2830 this->writeWord(dstColumnTypeId, out);
2831 this->writeWord(dstColumn, out);
2832 this->writeWord(srcColumn, out);
2833 this->writeWord(srcColumn, out);
2834 for (int j = 0; j < dstType.rows(); j++) {
2835 this->writeWord(j, out);
2836 }
2837 }
2838 columns.push_back(dstColumn);
2839 } else {
2840 // we're past the end of the src matrix, need to synthesize an identity-matrix column
2841 STArray<4, SpvId> values;
2842 for (int j = 0; j < dstType.rows(); ++j) {
2843 values.push_back((i == j) ? oneId : zeroId);
2844 }
2845 columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
2846 }
2847 }
2848
2849 return this->writeOpCompositeConstruct(dstType, columns, out);
2850 }
2851
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2852 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2853 TArray<SpvId>* currentColumn,
2854 TArray<SpvId>* columnIds,
2855 int rows,
2856 SpvId entry,
2857 OutputStream& out) {
2858 SkASSERT(currentColumn->size() < rows);
2859 currentColumn->push_back(entry);
2860 if (currentColumn->size() == rows) {
2861 // Synthesize this column into a vector.
2862 SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2863 columnIds->push_back(columnId);
2864 currentColumn->clear();
2865 }
2866 }
2867
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2868 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2869 const Type& type = c.type();
2870 SkASSERT(type.isMatrix());
2871 SkASSERT(!c.arguments().empty());
2872 const Type& arg0Type = c.arguments()[0]->type();
2873 // go ahead and write the arguments so we don't try to write new instructions in the middle of
2874 // an instruction
2875 STArray<16, SpvId> arguments;
2876 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2877 arguments.push_back(this->writeExpression(*arg, out));
2878 }
2879
2880 if (arguments.size() == 1 && arg0Type.isVector()) {
2881 // Special-case handling of float4 -> mat2x2.
2882 SkASSERT(type.rows() == 2 && type.columns() == 2);
2883 SkASSERT(arg0Type.columns() == 4);
2884 SpvId v[4];
2885 for (int i = 0; i < 4; ++i) {
2886 v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2887 }
2888 const Type& vecType = type.columnType(fContext);
2889 SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2890 SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2891 return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2892 }
2893
2894 int rows = type.rows();
2895 const Type& columnType = type.columnType(fContext);
2896 // SpvIds of completed columns of the matrix.
2897 STArray<4, SpvId> columnIds;
2898 // SpvIds of scalars we have written to the current column so far.
2899 STArray<4, SpvId> currentColumn;
2900 for (int i = 0; i < arguments.size(); i++) {
2901 const Type& argType = c.arguments()[i]->type();
2902 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2903 // This vector is a complete matrix column by itself and can be used as-is.
2904 columnIds.push_back(arguments[i]);
2905 } else if (argType.columns() == 1) {
2906 // This argument is a lone scalar and can be added to the current column as-is.
2907 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out);
2908 } else {
2909 // This argument needs to be decomposed into its constituent scalars.
2910 for (int j = 0; j < argType.columns(); ++j) {
2911 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2912 arguments[i], j, out);
2913 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out);
2914 }
2915 }
2916 }
2917 SkASSERT(columnIds.size() == type.columns());
2918 return this->writeOpCompositeConstruct(type, columnIds, out);
2919 }
2920
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2921 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2922 OutputStream& out) {
2923 return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2924 : this->writeVectorConstructor(c, out);
2925 }
2926
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2927 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2928 const Type& type = c.type();
2929 const Type& componentType = type.componentType();
2930 SkASSERT(type.isVector());
2931
2932 STArray<4, SpvId> arguments;
2933 for (int i = 0; i < c.arguments().size(); i++) {
2934 const Type& argType = c.arguments()[i]->type();
2935 SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2936
2937 SpvId arg = this->writeExpression(*c.arguments()[i], out);
2938 if (argType.isMatrix()) {
2939 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2940 // each scalar separately.
2941 SkASSERT(argType.rows() == 2);
2942 SkASSERT(argType.columns() == 2);
2943 for (int j = 0; j < 4; ++j) {
2944 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2945 j / 2, j % 2, out));
2946 }
2947 } else if (argType.isVector()) {
2948 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2949 // vector arguments at all, so we always extract each vector component and pass them
2950 // into OpCompositeConstruct individually.
2951 for (int j = 0; j < argType.columns(); j++) {
2952 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2953 }
2954 } else {
2955 arguments.push_back(arg);
2956 }
2957 }
2958
2959 return this->writeOpCompositeConstruct(type, arguments, out);
2960 }
2961
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2962 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2963 // Write the splat argument as a scalar, then splat it.
2964 SpvId argument = this->writeExpression(*c.argument(), out);
2965 return this->splat(c.type(), argument, out);
2966 }
2967
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2968 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2969 SkASSERT(c.type().isArray() || c.type().isStruct());
2970 auto ctorArgs = c.argumentSpan();
2971
2972 STArray<4, SpvId> arguments;
2973 for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2974 arguments.push_back(this->writeExpression(*arg, out));
2975 }
2976
2977 return this->writeOpCompositeConstruct(c.type(), arguments, out);
2978 }
2979
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2980 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2981 OutputStream& out) {
2982 const Type& type = c.type();
2983 if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2984 return this->writeExpression(*c.argument(), out);
2985 }
2986
2987 const Expression& ctorExpr = *c.argument();
2988 SpvId expressionId = this->writeExpression(ctorExpr, out);
2989 return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2990 }
2991
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2992 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2993 OutputStream& out) {
2994 const Type& ctorType = c.type();
2995 const Type& argType = c.argument()->type();
2996 SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2997
2998 // Write the composite that we are casting. If the actual type matches, we are done.
2999 SpvId compositeId = this->writeExpression(*c.argument(), out);
3000 if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
3001 return compositeId;
3002 }
3003
3004 // writeMatrixCopy can cast matrices to a different type.
3005 if (ctorType.isMatrix()) {
3006 return this->writeMatrixCopy(compositeId, argType, ctorType, out);
3007 }
3008
3009 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
3010 // components and convert each one manually.
3011 const Type& srcType = argType.componentType();
3012 const Type& dstType = ctorType.componentType();
3013
3014 STArray<4, SpvId> arguments;
3015 for (int index = 0; index < argType.columns(); ++index) {
3016 SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
3017 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
3018 }
3019
3020 return this->writeOpCompositeConstruct(ctorType, arguments, out);
3021 }
3022
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)3023 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
3024 OutputStream& out) {
3025 const Type& type = c.type();
3026 SkASSERT(type.isMatrix());
3027 SkASSERT(c.argument()->type().isScalar());
3028
3029 // Write out the scalar argument.
3030 SpvId diagonal = this->writeExpression(*c.argument(), out);
3031
3032 // Build the diagonal matrix.
3033 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3034
3035 const Type& vecType = type.columnType(fContext);
3036 STArray<4, SpvId> columnIds;
3037 STArray<4, SpvId> arguments;
3038 arguments.resize(type.rows());
3039 for (int column = 0; column < type.columns(); column++) {
3040 for (int row = 0; row < type.rows(); row++) {
3041 arguments[row] = (row == column) ? diagonal : zeroId;
3042 }
3043 columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
3044 }
3045 return this->writeOpCompositeConstruct(type, columnIds, out);
3046 }
3047
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)3048 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
3049 OutputStream& out) {
3050 // Write the input matrix.
3051 SpvId argument = this->writeExpression(*c.argument(), out);
3052
3053 // Use matrix-copy to resize the input matrix to its new size.
3054 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
3055 }
3056
get_storage_class_for_global_variable(const Variable & var,StorageClass fallbackStorageClass)3057 static StorageClass get_storage_class_for_global_variable(
3058 const Variable& var, StorageClass fallbackStorageClass) {
3059 SkASSERT(var.storage() == Variable::Storage::kGlobal);
3060
3061 if (var.type().typeKind() == Type::TypeKind::kSampler ||
3062 var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
3063 var.type().typeKind() == Type::TypeKind::kTexture) {
3064 return StorageClass::kUniformConstant;
3065 }
3066
3067 const Layout& layout = var.layout();
3068 ModifierFlags flags = var.modifierFlags();
3069 if (flags & ModifierFlag::kIn) {
3070 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3071 return StorageClass::kInput;
3072 }
3073 if (flags & ModifierFlag::kOut) {
3074 SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3075 return StorageClass::kOutput;
3076 }
3077 if (flags.isUniform()) {
3078 if (layout.fFlags & LayoutFlag::kPushConstant) {
3079 return StorageClass::kPushConstant;
3080 }
3081 return StorageClass::kUniform;
3082 }
3083 if (flags.isBuffer()) {
3084 return StorageClass::kStorageBuffer;
3085 }
3086 if (flags.isWorkgroup()) {
3087 return StorageClass::kWorkgroup;
3088 }
3089 return fallbackStorageClass;
3090 }
3091
getStorageClass(const Expression & expr)3092 StorageClass SPIRVCodeGenerator::getStorageClass(const Expression& expr) {
3093 switch (expr.kind()) {
3094 case Expression::Kind::kVariableReference: {
3095 const Variable& var = *expr.as<VariableReference>().variable();
3096 if (fActiveSpecialization) {
3097 const Expression** specializedExpr = fActiveSpecialization->find(&var);
3098 if (specializedExpr && (*specializedExpr)->is<FieldAccess>()) {
3099 return this->getStorageClass(**specializedExpr);
3100 }
3101 }
3102 if (var.storage() != Variable::Storage::kGlobal) {
3103 return StorageClass::kFunction;
3104 }
3105 return get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3106 }
3107 case Expression::Kind::kFieldAccess:
3108 return this->getStorageClass(*expr.as<FieldAccess>().base());
3109 case Expression::Kind::kIndex:
3110 return this->getStorageClass(*expr.as<IndexExpression>().base());
3111 default:
3112 return StorageClass::kFunction;
3113 }
3114 }
3115
getAccessChain(const Expression & expr,OutputStream & out)3116 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3117 switch (expr.kind()) {
3118 case Expression::Kind::kIndex: {
3119 const IndexExpression& indexExpr = expr.as<IndexExpression>();
3120 if (indexExpr.base()->is<Swizzle>()) {
3121 // Access chains don't directly support dynamically indexing into a swizzle, but we
3122 // can rewrite them into a supported form.
3123 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3124 out);
3125 }
3126 // All other index-expressions can be represented as typical access chains.
3127 TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3128 chain.push_back(this->writeExpression(*indexExpr.index(), out));
3129 return chain;
3130 }
3131 case Expression::Kind::kFieldAccess: {
3132 const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3133 TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3134 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3135 return chain;
3136 }
3137 case Expression::Kind::kVariableReference: {
3138 if (fActiveSpecialization) {
3139 const Expression** specializedFieldIndex =
3140 fActiveSpecialization->find(expr.as<VariableReference>().variable());
3141 if (specializedFieldIndex && (*specializedFieldIndex)->is<FieldAccess>()) {
3142 return this->getAccessChain(**specializedFieldIndex, out);
3143 }
3144 }
3145 [[fallthrough]];
3146 }
3147 default: {
3148 SpvId id = this->getLValue(expr, out)->getPointer();
3149 SkASSERT(id != NA);
3150 return TArray<SpvId>{id};
3151 }
3152 }
3153 SkUNREACHABLE;
3154 }
3155
3156 class PointerLValue : public SPIRVCodeGenerator::LValue {
3157 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,StorageClass storageClass)3158 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3159 SPIRVCodeGenerator::Precision precision, StorageClass storageClass)
3160 : fGen(gen)
3161 , fPointer(pointer)
3162 , fIsMemoryObject(isMemoryObject)
3163 , fType(type)
3164 , fPrecision(precision)
3165 , fStorageClass(storageClass) {}
3166
getPointer()3167 SpvId getPointer() override {
3168 return fPointer;
3169 }
3170
isMemoryObjectPointer() const3171 bool isMemoryObjectPointer() const override {
3172 return fIsMemoryObject;
3173 }
3174
storageClass() const3175 StorageClass storageClass() const override {
3176 return fStorageClass;
3177 }
3178
load(OutputStream & out)3179 SpvId load(OutputStream& out) override {
3180 return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3181 }
3182
store(SpvId value,OutputStream & out)3183 void store(SpvId value, OutputStream& out) override {
3184 if (!fIsMemoryObject) {
3185 // We are going to write into an access chain; this could represent one component of a
3186 // vector, or one element of an array. This has the potential to invalidate other,
3187 // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3188 // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3189 // relying on stale data, reset the store cache entirely when this happens.
3190 fGen.fStoreCache.reset();
3191 }
3192
3193 fGen.writeOpStore(fStorageClass, fPointer, value, out);
3194 }
3195
3196 private:
3197 SPIRVCodeGenerator& fGen;
3198 const SpvId fPointer;
3199 const bool fIsMemoryObject;
3200 const SpvId fType;
3201 const SPIRVCodeGenerator::Precision fPrecision;
3202 const StorageClass fStorageClass;
3203 };
3204
3205 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3206 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,StorageClass storageClass)3207 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3208 const Type& baseType, const Type& swizzleType, StorageClass storageClass)
3209 : fGen(gen)
3210 , fVecPointer(vecPointer)
3211 , fComponents(components)
3212 , fBaseType(&baseType)
3213 , fSwizzleType(&swizzleType)
3214 , fStorageClass(storageClass) {}
3215
applySwizzle(const ComponentArray & components,const Type & newType)3216 bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3217 ComponentArray updatedSwizzle;
3218 for (int8_t component : components) {
3219 if (component < 0 || component >= fComponents.size()) {
3220 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3221 return false;
3222 }
3223 updatedSwizzle.push_back(fComponents[component]);
3224 }
3225 fComponents = updatedSwizzle;
3226 fSwizzleType = &newType;
3227 return true;
3228 }
3229
storageClass() const3230 StorageClass storageClass() const override {
3231 return fStorageClass;
3232 }
3233
load(OutputStream & out)3234 SpvId load(OutputStream& out) override {
3235 SpvId base = fGen.nextId(fBaseType);
3236 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3237 SpvId result = fGen.nextId(fBaseType);
3238 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3239 fGen.writeWord(fGen.getType(*fSwizzleType), out);
3240 fGen.writeWord(result, out);
3241 fGen.writeWord(base, out);
3242 fGen.writeWord(base, out);
3243 for (int component : fComponents) {
3244 fGen.writeWord(component, out);
3245 }
3246 return result;
3247 }
3248
store(SpvId value,OutputStream & out)3249 void store(SpvId value, OutputStream& out) override {
3250 // use OpVectorShuffle to mix and match the vector components. We effectively create
3251 // a virtual vector out of the concatenation of the left and right vectors, and then
3252 // select components from this virtual vector to make the result vector. For
3253 // instance, given:
3254 // float3L = ...;
3255 // float3R = ...;
3256 // L.xz = R.xy;
3257 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3258 // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3259 // (3, 1, 4).
3260 SpvId base = fGen.nextId(fBaseType);
3261 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3262 SpvId shuffle = fGen.nextId(fBaseType);
3263 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3264 fGen.writeWord(fGen.getType(*fBaseType), out);
3265 fGen.writeWord(shuffle, out);
3266 fGen.writeWord(base, out);
3267 fGen.writeWord(value, out);
3268 for (int i = 0; i < fBaseType->columns(); i++) {
3269 // current offset into the virtual vector, defaults to pulling the unmodified
3270 // value from the left side
3271 int offset = i;
3272 // check to see if we are writing this component
3273 for (int j = 0; j < fComponents.size(); j++) {
3274 if (fComponents[j] == i) {
3275 // we're writing to this component, so adjust the offset to pull from
3276 // the correct component of the right side instead of preserving the
3277 // value from the left
3278 offset = (int) (j + fBaseType->columns());
3279 break;
3280 }
3281 }
3282 fGen.writeWord(offset, out);
3283 }
3284 fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3285 }
3286
3287 private:
3288 SPIRVCodeGenerator& fGen;
3289 const SpvId fVecPointer;
3290 ComponentArray fComponents;
3291 const Type* fBaseType;
3292 const Type* fSwizzleType;
3293 const StorageClass fStorageClass;
3294 };
3295
findUniformFieldIndex(const Variable & var) const3296 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3297 int* fieldIndex = fTopLevelUniformMap.find(&var);
3298 return fieldIndex ? *fieldIndex : -1;
3299 }
3300
/**
 * Produces an LValue through which `expr` can be loaded from or stored to. Any instructions
 * needed to compute the address are written to `out`.
 *
 * - Variable references:
 *   - Top-level uniforms become an OpAccessChain into the uniform buffer.
 *   - sk_SampleMask / sk_SampleMaskIn become element 0 of their array.
 *   - Other variables use their pointer from fVariableMap directly.
 * - Index and field-access expressions become a single OpAccessChain built from
 *   getAccessChain().
 * - Swizzles:
 *   - A swizzle is folded into the base lvalue when it accepts applySwizzle().
 *   - Otherwise a one-component swizzle becomes an access chain.
 *   - Otherwise it becomes a SwizzleLValue.
 * - Any other expression is stored into a fresh Function-storage temporary, and that
 *   temporary is returned.
 */
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    const Type& type = expr.type();
    // High-precision types load/store at default precision; everything else is RelaxedPrecision.
    Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                // Access uniforms via an AccessChain into the uniform-buffer struct.
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, StorageClass::kUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(
                        *this,
                        memberId,
                        /*isMemoryObjectPointer=*/true,
                        this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
                        precision,
                        StorageClass::kUniform);
            }

            // Every other variable must already have been assigned a pointer id.
            SpvId* entry = fVariableMap.find(&var);
            SkASSERTF(entry, "%s", expr.description().c_str());

            if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
                var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
                // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
                // represents sample masks as an array of uints.
                StorageClass storageClass =
                        get_storage_class_for_global_variable(var, StorageClass::kPrivate);
                SkASSERT(storageClass != StorageClass::kPrivate);
                SkASSERT(type.matches(*fContext.fTypes.fUInt));

                SpvId accessId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       accessId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type),
                                                       precision,
                                                       storageClass);
            }

            // Plain variable: use its pointer directly.
            SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
            return std::make_unique<PointerLValue>(*this, *entry,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision, this->getStorageClass(expr));
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // `chain` holds the root pointer followed by the index ids. The OpAccessChain word
            // count is therefore 3 (opcode, result type, result id) plus chain.size().
            TArray<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            StorageClass storageClass = this->getStorageClass(expr);
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, storageClass), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            // The resulting pointer addresses only part of an object, so stores through it
            // must invalidate the store cache (see PointerLValue::store).
            return std::make_unique<PointerLValue>(
                    *this,
                    member,
                    /*isMemoryObjectPointer=*/false,
                    this->getType(type,
                                  kDefaultTypeLayout,
                                  this->memoryLayoutForStorageClass(storageClass)),
                    precision,
                    storageClass);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            // If the base lvalue can absorb the swizzle (SwizzleLValue does this by composing
            // component lists), reuse it as-is.
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == NA) {
                // NOTE(review): after reporting this error, execution continues and emits
                // instructions that use NA as the base pointer. Presumably the error aborts
                // compilation before output is used; confirm.
                fContext.fErrors->error(swizzle.fPosition,
                                        "unable to retrieve lvalue from swizzle");
            }
            StorageClass storageClass = this->getStorageClass(*swizzle.base());
            if (swizzle.components().size() == 1) {
                // A single component can be addressed directly with an access chain.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this, member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision, storageClass);
            } else {
                // Multiple components need load/shuffle/store handling.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type, storageClass);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
            // lvalues should have been caught before code generation.
            //
            // This is with the exception of opaque handle types (textures/samplers) which are
            // always defined as UniformConstant pointers and don't need to be explicitly stored
            // into a temporary (which is handled explicitly in writeFunctionCallArgument).
            //
            // The OpVariable is written to fVariableBuffer rather than `out`. The value is then
            // computed and stored into it through `out`.
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, StorageClass::kFunction);
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeOpStore(StorageClass::kFunction, result, this->writeExpression(expr, out),
                               out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision,
                                                   StorageClass::kFunction);
        }
    }
}
3422
identifier(std::string_view name)3423 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3424 std::unique_ptr<Expression> expr =
3425 fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3426 return expr ? std::move(expr)
3427 : Poison::Make(Position(), fContext);
3428 }
3429
/**
 * Writes the value of a variable reference and returns its SpvId.
 *
 * Several builtins get special handling:
 * - The synthetic DEVICE_FRAGCOORDS / DEVICE_CLOCKWISE builtins read the raw, unflipped
 *   sk_FragCoord / sk_Clockwise.
 * - sk_SecondaryFragColor requires dual-source blending support. Without it an error is
 *   reported and NA is returned.
 * - sk_FragCoord and sk_Clockwise are adjusted using the RTFlip uniform, unless
 *   fForceNoRTFlip is set.
 *
 * Any other variable is handled as follows:
 * - It is constant-folded when it has a known compile-time value.
 * - A sampler with synthesized texture/sampler backing is combined with OpSampledImage.
 * - Otherwise it is loaded through its lvalue.
 */
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    switch (variable->layout().fBuiltin) {
        case DEVICE_FRAGCOORDS_BUILTIN: {
            // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
            // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
            // access the fragcoord; do so now.
            return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
        }
        case DEVICE_CLOCKWISE_BUILTIN: {
            // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
            // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
            // access front facing; do so now.
            return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
        }
        case SK_SECONDARYFRAGCOLOR_BUILTIN: {
            if (fCaps.fDualSourceBlendingSupport) {
                return this->getLValue(*this->identifier("sk_SecondaryFragColor"), out)->load(out);
            } else {
                fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
                return NA;
            }
        }
        case SK_FRAGCOORD_BUILTIN: {
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
            }

            // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
            this->addRTFlipUniform(ref.fPosition);
            // Compute the flipped coordinate from the RTFlip uniform (SKSL_RTFLIP_NAME):
            // flippedY = rtFlip.x + rtFlip.y * deviceY.
            // Use a uniform to flip the Y coordinate. The new expression will be written in
            // terms of $device_FragCoords, which is a fake variable that means "access the
            // underlying fragcoords directly without flipping it".
            static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
            // Lazily add the fake variable to the program's symbol table on first use.
            if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
                auto coordsVar = Variable::Make(/*pos=*/Position(),
                                                /*modifiersPosition=*/Position(),
                                                layout,
                                                ModifierFlag::kNone,
                                                fContext.fTypes.fFloat4.get(),
                                                DEVICE_COORDS_NAME,
                                                /*mangledName=*/"",
                                                /*builtin=*/true,
                                                Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(coordsVar));
            }
            std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId deviceCoordX = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
            SpvId deviceCoordY = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
            SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
                                                                    SwizzleComponent::W}, out);
            // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
            SpvId flippedY = this->writeBinaryExpression(
                                     *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
                                     *fContext.fTypes.fFloat, deviceCoordY,
                                     *fContext.fTypes.fFloat, out);

            // Compute `flippedY = u_RTFlip.x + flippedY`.
            flippedY = this->writeBinaryExpression(
                               *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
                               *fContext.fTypes.fFloat, flippedY,
                               *fContext.fTypes.fFloat, out);

            // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
            return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
                                                   {deviceCoordX, flippedY, deviceCoordZW},
                                                   out);
        }
        case SK_CLOCKWISE_BUILTIN: {
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
            }

            // Apply RTFlip to sk_Clockwise.
            this->addRTFlipUniform(ref.fPosition);
            // Use a uniform to flip the Y coordinate. The new expression will be written in
            // terms of $device_Clockwise, which is a fake variable that means "access the
            // underlying FrontFacing directly".
            static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
            // Lazily add the fake variable to the program's symbol table on first use.
            if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
                auto clockwiseVar = Variable::Make(/*pos=*/Position(),
                                                   /*modifiersPosition=*/Position(),
                                                   layout,
                                                   ModifierFlag::kNone,
                                                   fContext.fTypes.fBool.get(),
                                                   DEVICE_CLOCKWISE_NAME,
                                                   /*mangledName=*/"",
                                                   /*builtin=*/true,
                                                   Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
            }
            // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
            // we use the default convention of "counter-clockwise face is front".

            // Compute `positiveRTFlip = (rtFlip.y > 0)`.
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
            SpvId positiveRTFlip = this->writeBinaryExpression(
                                           *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
                                           *fContext.fTypes.fFloat, zero,
                                           *fContext.fTypes.fBool, out);

            // Compute `positiveRTFlip ^^ $device_Clockwise`.
            std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
            SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
            return this->writeBinaryExpression(
                           *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
                           *fContext.fTypes.fBool, deviceClockwiseID,
                           *fContext.fTypes.fBool, out);
        }
        default: {
            // Constant-propagate variables that have a known compile-time value.
            if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
                return this->writeExpression(*expr, out);
            }

            // A reference to a sampler variable at global scope with synthesized texture/sampler
            // backing should construct a function-scope combined image-sampler from the synthesized
            // constituents. This is the case in which a sample intrinsic was invoked.
            //
            // Variable references to opaque handles (texture/sampler) that appear as the argument
            // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
            if (fUseTextureSamplerPairs && variable->type().isSampler()) {
                if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
                    SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
                    SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
                    SkASSERT(imgPtr);
                    SkASSERT(samplerPtr);

                    SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
                                                  Precision::kDefault, *imgPtr, out);
                    SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
                                                      Precision::kDefault,
                                                      *samplerPtr,
                                                      out);
                    SpvId result = this->nextId(nullptr);
                    this->writeInstruction(SpvOpSampledImage,
                                           this->getType(variable->type()),
                                           result,
                                           img,
                                           sampler,
                                           out);
                    return result;
                }
                // Debug builds fail here. Release builds fall through to a plain load below.
                SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
            }
            return this->getLValue(ref, out)->load(out);
        }
    }
}
3591
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3592 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3593 if (expr.base()->type().isVector()) {
3594 SpvId base = this->writeExpression(*expr.base(), out);
3595 SpvId index = this->writeExpression(*expr.index(), out);
3596 SpvId result = this->nextId(nullptr);
3597 this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3598 index, out);
3599 return result;
3600 }
3601 return getLValue(expr, out)->load(out);
3602 }
3603
writeFieldAccess(const FieldAccess & f,OutputStream & out)3604 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3605 return getLValue(f, out)->load(out);
3606 }
3607
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3608 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3609 const ComponentArray& components,
3610 OutputStream& out) {
3611 size_t count = components.size();
3612 const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3613 SpvId base = this->writeExpression(baseExpr, out);
3614 if (count == 1) {
3615 return this->writeOpCompositeExtract(type, base, components[0], out);
3616 }
3617
3618 SpvId result = this->nextId(&type);
3619 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3620 this->writeWord(this->getType(type), out);
3621 this->writeWord(result, out);
3622 this->writeWord(base, out);
3623 this->writeWord(base, out);
3624 for (int component : components) {
3625 this->writeWord(component, out);
3626 }
3627 return result;
3628 }
3629
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3630 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3631 return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3632 }
3633
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3634 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3635 SpvId lhs, SpvId rhs,
3636 bool writeComponentwiseIfMatrix,
3637 SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3638 SpvOp_ ifBool, OutputStream& out) {
3639 SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3640 if (op == SpvOpUndef) {
3641 fContext.fErrors->error(operandType.fPosition,
3642 "unsupported operand for binary expression: " + operandType.description());
3643 return NA;
3644 }
3645 if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3646 return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3647 }
3648 SpvId result = this->nextId(&resultType);
3649 this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3650 return result;
3651 }
3652
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3653 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3654 const Type& operandType,
3655 SpvId lhs, SpvId rhs,
3656 SpvOp_ ifFloat, SpvOp_ ifInt,
3657 SpvOp_ ifUInt, SpvOp_ ifBool,
3658 OutputStream& out) {
3659 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3660 /*writeComponentwiseIfMatrix=*/true,
3661 ifFloat, ifInt, ifUInt, ifBool, out);
3662 }
3663
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3664 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3665 SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3666 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3667 return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3668 /*writeComponentwiseIfMatrix=*/false,
3669 ifFloat, ifInt, ifUInt, ifBool, out);
3670 }
3671
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3672 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3673 OutputStream& out) {
3674 if (operandType.isVector()) {
3675 SpvId result = this->nextId(nullptr);
3676 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3677 return result;
3678 }
3679 return id;
3680 }
3681
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3682 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3683 SpvOp_ floatOperator, SpvOp_ intOperator,
3684 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3685 OutputStream& out) {
3686 SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3687 SkASSERT(operandType.isMatrix());
3688 const Type& columnType = operandType.componentType().toCompound(fContext,
3689 operandType.rows(),
3690 1);
3691 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3692 operandType.rows(),
3693 1));
3694 SpvId boolType = this->getType(*fContext.fTypes.fBool);
3695 SpvId result = 0;
3696 for (int i = 0; i < operandType.columns(); i++) {
3697 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3698 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3699 SpvId compare = this->nextId(&operandType);
3700 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3701 SpvId merge = this->nextId(nullptr);
3702 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3703 if (result != 0) {
3704 SpvId next = this->nextId(nullptr);
3705 this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3706 result = next;
3707 } else {
3708 result = merge;
3709 }
3710 }
3711 return result;
3712 }
3713
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3714 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3715 SpvId operand,
3716 SpvOp_ op,
3717 OutputStream& out) {
3718 SkASSERT(operandType.isMatrix());
3719 const Type& columnType = operandType.columnType(fContext);
3720 SpvId columnTypeId = this->getType(columnType);
3721
3722 STArray<4, SpvId> columns;
3723 for (int i = 0; i < operandType.columns(); i++) {
3724 SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3725 SpvId dstColumn = this->nextId(&operandType);
3726 this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3727 columns.push_back(dstColumn);
3728 }
3729
3730 return this->writeOpCompositeConstruct(operandType, columns, out);
3731 }
3732
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3733 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3734 SpvId rhs, SpvOp_ op, OutputStream& out) {
3735 SkASSERT(operandType.isMatrix());
3736 const Type& columnType = operandType.columnType(fContext);
3737 SpvId columnTypeId = this->getType(columnType);
3738
3739 STArray<4, SpvId> columns;
3740 for (int i = 0; i < operandType.columns(); i++) {
3741 SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3742 SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3743 columns.push_back(this->nextId(&operandType));
3744 this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3745 }
3746 return this->writeOpCompositeConstruct(operandType, columns, out);
3747 }
3748
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3749 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3750 SkASSERT(type.isFloat());
3751 SpvId one = this->writeLiteral(1.0, type);
3752 SpvId reciprocal = this->nextId(&type);
3753 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3754 return reciprocal;
3755 }
3756
splat(const Type & type,SpvId id,OutputStream & out)3757 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3758 if (type.isScalar()) {
3759 // Scalars require no additional work; we can return the passed-in ID as is.
3760 } else {
3761 SkASSERT(type.isVector() || type.isMatrix());
3762 bool isMatrix = type.isMatrix();
3763
3764 // Splat the input scalar across a vector.
3765 int vectorSize = (isMatrix ? type.rows() : type.columns());
3766 const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3767
3768 STArray<4, SpvId> values;
3769 values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3770 id = this->writeOpCompositeConstruct(vectorType, values, out);
3771
3772 if (isMatrix) {
3773 // Splat the newly-synthesized vector into a matrix.
3774 STArray<4, SpvId> matArguments;
3775 matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3776 id = this->writeOpCompositeConstruct(type, matArguments, out);
3777 }
3778 }
3779
3780 return id;
3781 }
3782
types_match(const Type & a,const Type & b)3783 static bool types_match(const Type& a, const Type& b) {
3784 if (a.matches(b)) {
3785 return true;
3786 }
3787 return (a.typeKind() == b.typeKind()) &&
3788 (a.isScalar() || a.isVector() || a.isMatrix()) &&
3789 (a.columns() == b.columns() && a.rows() == b.rows()) &&
3790 a.componentType().numberKind() == b.componentType().numberKind();
3791 }
3792
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3793 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3794 SpvId lhs,
3795 const Type& rightType,
3796 SpvId rhs,
3797 const Type& resultType,
3798 OutputStream& out) {
3799 SpvId sum = NA;
3800 const Type& columnType = leftType.columnType(fContext);
3801 const Type& scalarType = rightType.componentType();
3802
3803 for (int n = 0; n < leftType.rows(); ++n) {
3804 // Extract mat[N] from the matrix.
3805 SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3806
3807 // Extract vec[N] from the vector.
3808 SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3809
3810 // Multiply them together.
3811 SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3812 scalarType, vecN,
3813 columnType, out);
3814
3815 // Sum all the components together.
3816 if (sum == NA) {
3817 sum = product;
3818 } else {
3819 sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3820 columnType, product,
3821 columnType, out);
3822 }
3823 }
3824
3825 return sum;
3826 }
3827
// Emits SPIR-V for `lhs op rhs`, given already-evaluated operand ids `lhs` (of type `leftType`)
// and `rhs` (of type `rightType`), and returns the id of the result (of type `resultType`).
//
// If the operand types do not match (per types_match), they are reconciled first:
//  - vector-op-scalar: float `*` and `/` (the latter via a reciprocal) use OpVectorTimesScalar;
//    otherwise the scalar is splatted into a vector of the left type.
//  - scalar-op-vector: float `*` uses OpVectorTimesScalar; otherwise the scalar is splatted.
//  - matrix operands: `*` uses the dedicated matrix opcodes (or the decomposed form when
//    fCaps.fRewriteMatrixVectorMultiply applies); other operators splat the scalar operand into
//    a matrix and recurse as matrix-op-matrix.
//  - anything else reports "unsupported mixed-type expression" and returns NA.
// The matching-type path then dispatches on the operator. Operators with no case report
// "unsupported token" and return NA. The BinaryExpression overload intercepts plain assignment
// and the short-circuiting && / ||, and strips compound assignment before calling this.
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
                                                const Type& rightType, SpvId rhs,
                                                const Type& resultType, OutputStream& out) {
    // The comma operator ignores the type of the left-hand side entirely.
    if (op.kind() == Operator::Kind::COMMA) {
        return rhs;
    }
    // overall type we are operating on: float2, int, uint4...
    const Type* operandType;
    if (types_match(leftType, rightType)) {
        operandType = &leftType;
    } else {
        // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
        // handling in SPIR-V
        if (leftType.isVector() && rightType.isNumber()) {
            // vector-op-scalar. Float `*` maps directly onto OpVectorTimesScalar. Float `/`
            // does too, after replacing the scalar with its reciprocal.
            if (resultType.componentType().isFloat()) {
                switch (op.kind()) {
                    case Operator::Kind::SLASH: {
                        rhs = this->writeReciprocal(rightType, rhs, out);
                        [[fallthrough]];
                    }
                    case Operator::Kind::STAR: {
                        SpvId result = this->nextId(&resultType);
                        this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                               result, lhs, rhs, out);
                        return result;
                    }
                    default:
                        break;
                }
            }
            // Vectorize the right-hand side.
            STArray<4, SpvId> arguments;
            arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
            rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
            operandType = &leftType;
        } else if (rightType.isVector() && leftType.isNumber()) {
            // scalar-op-vector. Only float `*` has a direct opcode. OpVectorTimesScalar takes
            // the vector operand first, so `rhs` and `lhs` are swapped below.
            if (resultType.componentType().isFloat()) {
                if (op.kind() == Operator::Kind::STAR) {
                    SpvId result = this->nextId(&resultType);
                    this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                    return result;
                }
            }
            // Vectorize the left-hand side.
            STArray<4, SpvId> arguments;
            arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
            lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
            operandType = &rightType;
        } else if (leftType.isMatrix()) {
            if (op.kind() == Operator::Kind::STAR) {
                // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
                // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
                if (fCaps.fRewriteMatrixVectorMultiply &&
                    rightType.isVector() &&
                    !resultType.highPrecision()) {
                    return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
                                                                     resultType, out);
                }

                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvOp_ spvop;
                if (rightType.isMatrix()) {
                    spvop = SpvOpMatrixTimesMatrix;
                } else if (rightType.isVector()) {
                    spvop = SpvOpMatrixTimesVector;
                } else {
                    SkASSERT(rightType.isScalar());
                    spvop = SpvOpMatrixTimesScalar;
                }
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
                return result;
            } else {
                // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(rightType.isScalar());

                // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId rhsMatrix = this->splat(leftType, rhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
                                                   resultType, out);
            }
        } else if (rightType.isMatrix()) {
            if (op.kind() == Operator::Kind::STAR) {
                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvId result = this->nextId(&resultType);
                if (leftType.isVector()) {
                    this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
                                           result, lhs, rhs, out);
                } else {
                    SkASSERT(leftType.isScalar());
                    // OpMatrixTimesScalar takes the matrix first, so the operands are swapped.
                    this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                }
                return result;
            } else {
                // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(leftType.isScalar());

                // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId lhsMatrix = this->splat(rightType, lhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
                                                   resultType, out);
            }
        } else {
            fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
            return NA;
        }
    }

    // From here on, `lhs` and `rhs` are both of type `*operandType`.
    switch (op.kind()) {
        case Operator::Kind::EQEQ: {
            // Matrices, structs and arrays are compared piecewise by dedicated helpers.
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
                                                   SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            SkASSERT(resultType.isBoolean());
            // Vectors compare componentwise into a bool vector of the same shape; scalars
            // compare directly into `resultType`.
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            if (lhs == rhs) {
                // This ignores the effects of NaN.
                return this->writeOpConstantTrue(*fContext.fTypes.fBool);
            }
            // The componentwise result is collapsed to a single bool with SpvOpAll.
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFOrdEqual, SpvOpIEqual,
                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
                                    *operandType, SpvOpAll, out);
        }
        case Operator::Kind::NEQ:
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
                                                   SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            // Scalar/vector `!=` shares the not-equal path with logical XOR.
            [[fallthrough]];
        case Operator::Kind::LOGICALXOR:
            SkASSERT(resultType.isBoolean());
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            if (lhs == rhs) {
                // This ignores the effects of NaN.
                return this->writeOpConstantFalse(*fContext.fTypes.fBool);
            }
            // The componentwise result is collapsed to a single bool with SpvOpAny.
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFUnordNotEqual, SpvOpINotEqual,
                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
                                                               out),
                                    *operandType, SpvOpAny, out);
        case Operator::Kind::GT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
                                              SpvOpUGreaterThan, SpvOpUndef, out);
        case Operator::Kind::LT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
        case Operator::Kind::GTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
        case Operator::Kind::LTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
                                              SpvOpULessThanEqual, SpvOpUndef, out);
        case Operator::Kind::PLUS:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFAdd, SpvOpIAdd,
                                                                   SpvOpIAdd, SpvOpUndef, out);
        case Operator::Kind::MINUS:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFSub, SpvOpISub,
                                                                   SpvOpISub, SpvOpUndef, out);
        case Operator::Kind::STAR:
            if (leftType.isMatrix() && rightType.isMatrix()) {
                // matrix multiply (a true matrix product, not componentwise)
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
                                       lhs, rhs, out);
                return result;
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
        case Operator::Kind::SLASH:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFDiv, SpvOpSDiv,
                                                                   SpvOpUDiv, SpvOpUndef, out);
        case Operator::Kind::PERCENT:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
        case Operator::Kind::SHL:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
                                              SpvOpUndef, out);
        case Operator::Kind::SHR:
            // Signed operands shift arithmetically, unsigned operands logically.
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
                                              SpvOpUndef, out);
        case Operator::Kind::BITWISEAND:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
        case Operator::Kind::BITWISEOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
        case Operator::Kind::BITWISEXOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
        default:
            fContext.fErrors->error(Position(), "unsupported token");
            return NA;
    }
}
4072
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4073 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
4074 SpvId rhs, OutputStream& out) {
4075 // The inputs must be arrays, and the op must be == or !=.
4076 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4077 SkASSERT(arrayType.isArray());
4078 const Type& componentType = arrayType.componentType();
4079 const int arraySize = arrayType.columns();
4080 SkASSERT(arraySize > 0);
4081
4082 // Synthesize equality checks for each item in the array.
4083 const Type& boolType = *fContext.fTypes.fBool;
4084 SpvId allComparisons = NA;
4085 for (int index = 0; index < arraySize; ++index) {
4086 // Get the left and right item in the array.
4087 SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
4088 SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
4089 // Use `writeBinaryExpression` with the requested == or != operator on these items.
4090 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
4091 componentType, itemR, boolType, out);
4092 // Merge this comparison result with all the other comparisons we've done.
4093 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4094 }
4095 return allComparisons;
4096 }
4097
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4098 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4099 SpvId rhs, OutputStream& out) {
4100 // The inputs must be structs containing fields, and the op must be == or !=.
4101 SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4102 SkASSERT(structType.isStruct());
4103 SkSpan<const Field> fields = structType.fields();
4104 SkASSERT(!fields.empty());
4105
4106 // Synthesize equality checks for each field in the struct.
4107 const Type& boolType = *fContext.fTypes.fBool;
4108 SpvId allComparisons = NA;
4109 for (int index = 0; index < (int)fields.size(); ++index) {
4110 // Get the left and right versions of this field.
4111 const Type& fieldType = *fields[index].fType;
4112
4113 SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4114 SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4115 // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4116 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4117 boolType, out);
4118 // Merge this comparison result with all the other comparisons we've done.
4119 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4120 }
4121 return allComparisons;
4122 }
4123
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4124 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4125 OutputStream& out) {
4126 // If this is the first entry, we don't need to merge comparison results with anything.
4127 if (allComparisons == NA) {
4128 return comparison;
4129 }
4130 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4131 const Type& boolType = *fContext.fTypes.fBool;
4132 SpvId boolTypeId = this->getType(boolType);
4133 SpvId logicalOp = this->nextId(&boolType);
4134 switch (op.kind()) {
4135 case Operator::Kind::EQEQ:
4136 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4137 comparison, allComparisons, out);
4138 break;
4139 case Operator::Kind::NEQ:
4140 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4141 comparison, allComparisons, out);
4142 break;
4143 default:
4144 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4145 return NA;
4146 }
4147 return logicalOp;
4148 }
4149
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4150 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4151 const Expression* left = b.left().get();
4152 const Expression* right = b.right().get();
4153 Operator op = b.getOperator();
4154
4155 switch (op.kind()) {
4156 case Operator::Kind::EQ: {
4157 // Handles assignment.
4158 SpvId rhs = this->writeExpression(*right, out);
4159 this->getLValue(*left, out)->store(rhs, out);
4160 return rhs;
4161 }
4162 case Operator::Kind::LOGICALAND:
4163 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4164 return this->writeLogicalAnd(*b.left(), *b.right(), out);
4165
4166 case Operator::Kind::LOGICALOR:
4167 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4168 return this->writeLogicalOr(*b.left(), *b.right(), out);
4169
4170 default:
4171 break;
4172 }
4173
4174 std::unique_ptr<LValue> lvalue;
4175 SpvId lhs;
4176 if (op.isAssignment()) {
4177 lvalue = this->getLValue(*left, out);
4178 lhs = lvalue->load(out);
4179 } else {
4180 lvalue = nullptr;
4181 lhs = this->writeExpression(*left, out);
4182 }
4183
4184 SpvId rhs = this->writeExpression(*right, out);
4185 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4186 right->type(), rhs, b.type(), out);
4187 if (lvalue) {
4188 lvalue->store(result, out);
4189 }
4190 return result;
4191 }
4192
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)4193 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
4194 OutputStream& out) {
4195 SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
4196 SpvId lhs = this->writeExpression(left, out);
4197
4198 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4199
4200 SpvId rhsLabel = this->nextId(nullptr);
4201 SpvId end = this->nextId(nullptr);
4202 SpvId lhsBlock = fCurrentBlock;
4203 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4204 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
4205 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4206 SpvId rhs = this->writeExpression(right, out);
4207 SpvId rhsBlock = fCurrentBlock;
4208 this->writeInstruction(SpvOpBranch, end, out);
4209 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4210 SpvId result = this->nextId(nullptr);
4211 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
4212 lhsBlock, rhs, rhsBlock, out);
4213
4214 return result;
4215 }
4216
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)4217 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
4218 OutputStream& out) {
4219 SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
4220 SpvId lhs = this->writeExpression(left, out);
4221
4222 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4223
4224 SpvId rhsLabel = this->nextId(nullptr);
4225 SpvId end = this->nextId(nullptr);
4226 SpvId lhsBlock = fCurrentBlock;
4227 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4228 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
4229 this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4230 SpvId rhs = this->writeExpression(right, out);
4231 SpvId rhsBlock = fCurrentBlock;
4232 this->writeInstruction(SpvOpBranch, end, out);
4233 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4234 SpvId result = this->nextId(nullptr);
4235 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
4236 lhsBlock, rhs, rhsBlock, out);
4237
4238 return result;
4239 }
4240
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)4241 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
4242 const Type& type = t.type();
4243 SpvId test = this->writeExpression(*t.test(), out);
4244 if (t.ifTrue()->type().columns() == 1 &&
4245 Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
4246 Analysis::IsCompileTimeConstant(*t.ifFalse())) {
4247 // both true and false are constants, can just use OpSelect
4248 SpvId result = this->nextId(nullptr);
4249 SpvId trueId = this->writeExpression(*t.ifTrue(), out);
4250 SpvId falseId = this->writeExpression(*t.ifFalse(), out);
4251 this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
4252 out);
4253 return result;
4254 }
4255
4256 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4257
4258 // was originally using OpPhi to choose the result, but for some reason that is crashing on
4259 // Adreno. Switched to storing the result in a temp variable as glslang does.
4260 SpvId var = this->nextId(nullptr);
4261 this->writeInstruction(SpvOpVariable, this->getPointerType(type, StorageClass::kFunction),
4262 var, SpvStorageClassFunction, fVariableBuffer);
4263 SpvId trueLabel = this->nextId(nullptr);
4264 SpvId falseLabel = this->nextId(nullptr);
4265 SpvId end = this->nextId(nullptr);
4266 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4267 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
4268 this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
4269 this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifTrue(), out), out);
4270 this->writeInstruction(SpvOpBranch, end, out);
4271 this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
4272 this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifFalse(), out), out);
4273 this->writeInstruction(SpvOpBranch, end, out);
4274 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4275 SpvId result = this->nextId(&type);
4276 this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
4277
4278 return result;
4279 }
4280
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4281 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4282 const Type& type = p.type();
4283 if (p.getOperator().kind() == Operator::Kind::MINUS) {
4284 SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4285 SkASSERT(negateOp != SpvOpUndef);
4286 SpvId expr = this->writeExpression(*p.operand(), out);
4287 if (type.isMatrix()) {
4288 return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4289 }
4290 SpvId result = this->nextId(&type);
4291 SpvId typeId = this->getType(type);
4292 this->writeInstruction(negateOp, typeId, result, expr, out);
4293 return result;
4294 }
4295 switch (p.getOperator().kind()) {
4296 case Operator::Kind::PLUS:
4297 return this->writeExpression(*p.operand(), out);
4298
4299 case Operator::Kind::PLUSPLUS: {
4300 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4301 SpvId one = this->writeLiteral(1.0, type.componentType());
4302 one = this->splat(type, one, out);
4303 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4304 lv->load(out), one,
4305 SpvOpFAdd, SpvOpIAdd,
4306 SpvOpIAdd, SpvOpUndef,
4307 out);
4308 lv->store(result, out);
4309 return result;
4310 }
4311 case Operator::Kind::MINUSMINUS: {
4312 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4313 SpvId one = this->writeLiteral(1.0, type.componentType());
4314 one = this->splat(type, one, out);
4315 SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4316 lv->load(out), one,
4317 SpvOpFSub, SpvOpISub,
4318 SpvOpISub, SpvOpUndef,
4319 out);
4320 lv->store(result, out);
4321 return result;
4322 }
4323 case Operator::Kind::LOGICALNOT: {
4324 SkASSERT(p.operand()->type().isBoolean());
4325 SpvId result = this->nextId(nullptr);
4326 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4327 this->writeExpression(*p.operand(), out), out);
4328 return result;
4329 }
4330 case Operator::Kind::BITWISENOT: {
4331 SpvId result = this->nextId(nullptr);
4332 this->writeInstruction(SpvOpNot, this->getType(type), result,
4333 this->writeExpression(*p.operand(), out), out);
4334 return result;
4335 }
4336 default:
4337 SkDEBUGFAILF("unsupported prefix expression: %s",
4338 p.description(OperatorPrecedence::kExpression).c_str());
4339 return NA;
4340 }
4341 }
4342
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4343 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4344 const Type& type = p.type();
4345 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4346 SpvId result = lv->load(out);
4347 SpvId one = this->writeLiteral(1.0, type.componentType());
4348 one = this->splat(type, one, out);
4349 switch (p.getOperator().kind()) {
4350 case Operator::Kind::PLUSPLUS: {
4351 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4352 SpvOpFAdd, SpvOpIAdd,
4353 SpvOpIAdd, SpvOpUndef,
4354 out);
4355 lv->store(temp, out);
4356 return result;
4357 }
4358 case Operator::Kind::MINUSMINUS: {
4359 SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4360 SpvOpFSub, SpvOpISub,
4361 SpvOpISub, SpvOpUndef,
4362 out);
4363 lv->store(temp, out);
4364 return result;
4365 }
4366 default:
4367 SkDEBUGFAILF("unsupported postfix expression %s",
4368 p.description(OperatorPrecedence::kExpression).c_str());
4369 return NA;
4370 }
4371 }
4372
writeLiteral(const Literal & l)4373 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4374 return this->writeLiteral(l.value(), l.type());
4375 }
4376
writeLiteral(double value,const Type & type)4377 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4378 switch (type.numberKind()) {
4379 case Type::NumberKind::kFloat: {
4380 float floatVal = value;
4381 int32_t valueBits;
4382 memcpy(&valueBits, &floatVal, sizeof(valueBits));
4383 return this->writeOpConstant(type, valueBits);
4384 }
4385 case Type::NumberKind::kBoolean: {
4386 return value ? this->writeOpConstantTrue(type)
4387 : this->writeOpConstantFalse(type);
4388 }
4389 default: {
4390 return this->writeOpConstant(type, (SKSL_INT)value);
4391 }
4392 }
4393 }
4394
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)4395 void SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
4396 SpvId result = fFunctionMap[{&f, fActiveSpecializationIndex}];
4397 SpvId returnTypeId = this->getType(f.returnType());
4398 SpvId functionTypeId = this->getFunctionType(f);
4399 this->writeInstruction(SpvOpFunction, returnTypeId, result,
4400 SpvFunctionControlMaskNone, functionTypeId, out);
4401 std::string mangledName = f.mangledName();
4402
4403 // For specialized functions, tack on `_param1_param2` to the function name.
4404 Analysis::GetParameterMappingsForFunction(
4405 f, fSpecializationInfo, fActiveSpecializationIndex,
4406 [&](int, const Variable*, const Expression* expr) {
4407 std::string name = expr->description();
4408 std::replace_if(name.begin(), name.end(), [](char c) { return !isalnum(c); }, '_');
4409
4410 mangledName += "_" + name;
4411 });
4412
4413 this->writeInstruction(SpvOpName,
4414 result,
4415 std::string_view(mangledName.c_str(), mangledName.size()),
4416 fNameBuffer);
4417 for (const Variable* parameter : f.parameters()) {
4418 const Variable* specializedVar = nullptr;
4419 if (fActiveSpecialization) {
4420 if (const Expression** specializedExpr = fActiveSpecialization->find(parameter)) {
4421 if ((*specializedExpr)->is<FieldAccess>()) {
4422 continue;
4423 }
4424 SkASSERT((*specializedExpr)->is<VariableReference>());
4425 specializedVar = (*specializedExpr)->as<VariableReference>().variable();
4426 }
4427 }
4428
4429 if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
4430 auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
4431
4432 SpvId textureId = this->nextId(nullptr);
4433 fVariableMap.set(texture, textureId);
4434
4435 SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
4436 this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
4437
4438 if (specializedVar) {
4439 const auto* p = fSynthesizedSamplerMap.find(specializedVar);
4440 SkASSERT(p);
4441 const SpvId* uniformId = fVariableMap.find((*p)->fSampler.get());
4442 SkASSERT(uniformId);
4443 fVariableMap.set(sampler, *uniformId);
4444 } else {
4445 SpvId samplerId = this->nextId(nullptr);
4446 fVariableMap.set(sampler, samplerId);
4447
4448 SpvId samplerType =
4449 this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);
4450 this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
4451 }
4452 } else {
4453 if (specializedVar) {
4454 const SpvId* uniformId = fVariableMap.find(specializedVar);
4455 SkASSERT(uniformId);
4456 fVariableMap.set(parameter, *uniformId);
4457 } else {
4458 SpvId id = this->nextId(nullptr);
4459 fVariableMap.set(parameter, id);
4460
4461 SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
4462 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
4463 }
4464 }
4465 }
4466 }
4467
writeFunction(const FunctionDefinition & f,OutputStream & out)4468 void SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
4469 if (const Analysis::Specializations* specializations =
4470 fSpecializationInfo.fSpecializationMap.find(&f.declaration())) {
4471 for (int i = 0; i < specializations->size(); i++) {
4472 this->writeFunctionInstantiation(f, i, &specializations->at(i), out);
4473 }
4474 } else {
4475 this->writeFunctionInstantiation(f,
4476 Analysis::kUnspecialized,
4477 /*specializedParams=*/nullptr,
4478 out);
4479 }
4480 }
4481
writeFunctionInstantiation(const FunctionDefinition & f,Analysis::SpecializationIndex specializationIndex,const Analysis::SpecializedParameters * specializedParams,OutputStream & out)4482 void SPIRVCodeGenerator::writeFunctionInstantiation(
4483 const FunctionDefinition& f,
4484 Analysis::SpecializationIndex specializationIndex,
4485 const Analysis::SpecializedParameters* specializedParams,
4486 OutputStream& out) {
4487 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4488
4489 fVariableBuffer.reset();
4490 fActiveSpecialization = specializedParams;
4491 fActiveSpecializationIndex = specializationIndex;
4492 this->writeFunctionStart(f.declaration(), out);
4493 fCurrentBlock = 0;
4494 this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
4495 StringStream bodyBuffer;
4496 this->writeBlock(f.body()->as<Block>(), bodyBuffer);
4497 fActiveSpecialization = nullptr;
4498 fActiveSpecializationIndex = Analysis::kUnspecialized;
4499 write_stringstream(fVariableBuffer, out);
4500 if (f.declaration().isMain()) {
4501 write_stringstream(fGlobalInitializersBuffer, out);
4502 }
4503 write_stringstream(bodyBuffer, out);
4504 if (fCurrentBlock) {
4505 if (f.declaration().returnType().isVoid()) {
4506 this->writeInstruction(SpvOpReturn, out);
4507 } else {
4508 this->writeInstruction(SpvOpUnreachable, out);
4509 }
4510 }
4511 this->writeInstruction(SpvOpFunctionEnd, out);
4512 this->pruneConditionalOps(conditionalOps);
4513 }
4514
writeLayout(const Layout & layout,SpvId target,Position pos)4515 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4516 bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4517 if (layout.fLocation >= 0) {
4518 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4519 fDecorationBuffer);
4520 }
4521 if (layout.fBinding >= 0) {
4522 if (isPushConstant) {
4523 fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4524 } else {
4525 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4526 fDecorationBuffer);
4527 }
4528 }
4529 if (layout.fIndex >= 0) {
4530 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4531 fDecorationBuffer);
4532 }
4533 if (layout.fSet >= 0) {
4534 if (isPushConstant) {
4535 fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4536 } else {
4537 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4538 fDecorationBuffer);
4539 }
4540 }
4541 if (layout.fInputAttachmentIndex >= 0) {
4542 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4543 layout.fInputAttachmentIndex, fDecorationBuffer);
4544 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4545 }
4546 if (layout.fBuiltin >= 0 && (layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
4547 layout.fBuiltin != SK_SECONDARYFRAGCOLOR_BUILTIN)) {
4548 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4549 fDecorationBuffer);
4550 }
4551 }
4552
writeFieldLayout(const Layout & layout,SpvId target,int member)4553 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4554 // 'binding' and 'set' can not be applied to struct members
4555 SkASSERT(layout.fBinding == -1);
4556 SkASSERT(layout.fSet == -1);
4557 if (layout.fLocation >= 0) {
4558 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4559 layout.fLocation, fDecorationBuffer);
4560 }
4561 if (layout.fIndex >= 0) {
4562 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4563 layout.fIndex, fDecorationBuffer);
4564 }
4565 if (layout.fInputAttachmentIndex >= 0) {
4566 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4567 layout.fInputAttachmentIndex, fDecorationBuffer);
4568 }
4569 if (layout.fBuiltin >= 0) {
4570 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4571 layout.fBuiltin, fDecorationBuffer);
4572 }
4573 }
4574
memoryLayoutForStorageClass(StorageClass storageClass)4575 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(StorageClass storageClass) {
4576 return storageClass == StorageClass::kPushConstant ||
4577 storageClass == StorageClass::kStorageBuffer
4578 ? MemoryLayout(MemoryLayout::Standard::k430)
4579 : fDefaultMemoryLayout;
4580 }
4581
memoryLayoutForVariable(const Variable & v) const4582 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4583 bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4584 bool buffer = v.modifierFlags().isBuffer();
4585 return pushConstant || buffer ? MemoryLayout(MemoryLayout::Standard::k430)
4586 : fDefaultMemoryLayout;
4587 }
4588
// Emits an interface block (its struct type decoration, pointer type, OpVariable and layout
// decorations), records the variable in fVariableMap, and returns the OpVariable's SpvId.
//
// Special case: when the program needs an RTFlip uniform, `appendRTFlip` is true, RTFlip has not
// been written yet, and the block's type is a struct, the block is NOT emitted as-is. A copy whose
// struct has an extra trailing float2 RTFlip field is emitted instead (via a recursive call with
// appendRTFlip=false), and the original variable is mapped to that copy's SpvId.
//
// If the block's type is not supported by its memory layout, an error is reported and a fresh
// (unused) SpvId is returned.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    // NOTE(review): this id is discarded on the error path and on the RTFlip path below.
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        return this->nextId(nullptr);
    }
    // kFunction is passed as the second argument — presumably the fallback when the variable's
    // modifiers don't select a storage class; confirm against get_storage_class_for_global_variable.
    StorageClass storageClass =
            get_storage_class_for_global_variable(intfVar, StorageClass::kFunction);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
        !fWroteRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        SkSpan<const Field> fieldSpan = type.fields();
        TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
        // The RTFlip field is a float2 placed at the configured fRTFlipOffset.
        fields.emplace_back(Position(),
                            Layout(LayoutFlag::kNone,
                                   /*location=*/-1,
                                   fProgram.fConfig->fSettings.fRTFlipOffset,
                                   /*binding=*/-1,
                                   /*index=*/-1,
                                   /*set=*/-1,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            ModifierFlag::kNone,
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The new struct type and variable are handed to the program's symbol table; the
            // program's pool is attached while they are created (presumably so the new IR nodes
            // are allocated from it).
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Variable::Make(intfVar.fPosition,
                                   intfVar.modifiersPosition(),
                                   intfVar.layout(),
                                   intfVar.modifierFlags(),
                                   rtFlipStructType,
                                   intfVar.name(),
                                   /*mangledName=*/"",
                                   intfVar.isBuiltin(),
                                   intfVar.storage()));
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
            // Emit the copy through the normal (non-RTFlip) path.
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Register a FieldSymbol for the RTFlip field, which is the last field of the new
            // struct.
            fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // References to the original variable resolve to the modified copy's OpVariable.
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
    if (intfVar.layout().fBuiltin == -1) {
        // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
        // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
        // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
        // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
        // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
        // would benefit from SPV_KHR_variable_pointers capabilities).
        bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
        this->writeInstruction(SpvOpDecorate,
                               typeId,
                               isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
                               fDecorationBuffer);
    }
    // Declare a pointer-to-block type in the block's storage class, then the variable itself.
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType,
                           get_storage_class_spv_id(storageClass), typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result,
                           get_storage_class_spv_id(storageClass), fConstantBuffer);
    Layout layout = intfVar.layout();
    // Uniform and storage-buffer blocks without an explicit `set` get the default uniform set.
    if ((storageClass == StorageClass::kUniform ||
         storageClass == StorageClass::kStorageBuffer) && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
4676
4677 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4678 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4679 // always reference by value.
4680 //
4681 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4682 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4683 // OpVectorExtractDynamic (see writeIndexExpression).
4684 //
4685 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4686 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4687 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4688 return varDecl.var()->modifierFlags().isConst() &&
4689 (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4690 (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4691 Analysis::IsCompileTimeConstant(*varDecl.value()));
4692 }
4693
writeGlobalVarDeclaration(ProgramKind kind,const VarDeclaration & varDecl)4694 bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
4695 const VarDeclaration& varDecl) {
4696 const Variable* var = varDecl.var();
4697 const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
4698 const LayoutFlags kPermittedBackendFlags =
4699 LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
4700 if (backendFlags & ~kPermittedBackendFlags) {
4701 fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
4702 return false;
4703 }
4704
4705 // If this global variable is a compile-time constant then we'll emit OpConstant or
4706 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4707 if (is_vardecl_compile_time_constant(varDecl)) {
4708 return true;
4709 }
4710
4711 StorageClass storageClass =
4712 get_storage_class_for_global_variable(*var, StorageClass::kPrivate);
4713 if (storageClass == StorageClass::kUniform || storageClass == StorageClass::kStorageBuffer) {
4714 // Top-level uniforms are emitted in writeUniformBuffer.
4715 fTopLevelUniforms.push_back(&varDecl);
4716 return true;
4717 }
4718
4719 if (fUseTextureSamplerPairs && var->type().isSampler()) {
4720 if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
4721 fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
4722 "and sampler indices");
4723 return false;
4724 }
4725 SkASSERT(storageClass == StorageClass::kUniformConstant);
4726
4727 auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
4728 this->writeGlobalVar(kind, storageClass, *texture);
4729 this->writeGlobalVar(kind, storageClass, *sampler);
4730
4731 return true;
4732 }
4733
4734 SpvId id = this->writeGlobalVar(kind, storageClass, *var);
4735 if (id != NA && varDecl.value()) {
4736 SkASSERT(!fCurrentBlock);
4737 fCurrentBlock = NA;
4738 SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
4739 this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
4740 fCurrentBlock = 0;
4741 }
4742 return true;
4743 }
4744
// Emits an OpVariable for a global variable in `storageClass`, together with its OpName and
// decorations. Records the variable in fVariableMap and returns its SpvId. Returns NA without
// emitting anything for the frag-color built-ins (SK_FRAGCOLOR_BUILTIN /
// SK_SECONDARYFRAGCOLOR_BUILTIN) when `kind` is not a fragment program.
SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
                                         StorageClass storageClass,
                                         const Variable& var) {
    // Work on a copy of the layout; the built-in handling and default-set logic below adjust it.
    Layout layout = var.layout();
    const ModifierFlags flags = var.modifierFlags();
    const Type* type = &var.type();
    switch (layout.fBuiltin) {
        case SK_FRAGCOLOR_BUILTIN:
        case SK_SECONDARYFRAGCOLOR_BUILTIN:
            // Outside of fragment programs the frag-color built-ins are not emitted.
            if (!ProgramConfig::IsFragment(kind)) {
                SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
                return NA;
            }
            break;

        case SK_SAMPLEMASKIN_BUILTIN:
        case SK_SAMPLEMASK_BUILTIN:
            // SkSL exposes the sample mask as a scalar `uint`, but SPIR-V's SampleMask built-in
            // is an array of 32-bit integers (GLSL likewise declares gl_SampleMask as an array of
            // signed `int`). Wrap the type in a one-element array and decorate the variable with
            // SpvBuiltInSampleMask.
            type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
            layout.fBuiltin = SpvBuiltInSampleMask;
            break;
    }

    // Add this global to the variable map.
    SpvId id = this->nextId(type);
    fVariableMap.set(&var, id);

    // UniformConstant globals without an explicit `set` get the default uniform set.
    if (layout.fSet < 0 && storageClass == StorageClass::kUniformConstant) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }

    SpvId typeId = this->getPointerType(*type,
                                        layout,
                                        this->memoryLayoutForStorageClass(storageClass),
                                        storageClass);
    this->writeInstruction(SpvOpVariable, typeId, id,
                           get_storage_class_spv_id(storageClass), fConstantBuffer);
    this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
    this->writeLayout(layout, id, var.fPosition);
    // Interpolation qualifiers.
    if (flags & ModifierFlag::kFlat) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
    }
    if (flags & ModifierFlag::kNoPerspective) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
                               fDecorationBuffer);
    }
    // Access qualifiers: writeonly -> NonReadable, otherwise readonly -> NonWritable.
    if (flags.isWriteOnly()) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
    } else if (flags.isReadOnly()) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
    }

    return id;
}
4800
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)4801 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
4802 // If this variable is a compile-time constant then we'll emit OpConstant or
4803 // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4804 if (is_vardecl_compile_time_constant(varDecl)) {
4805 return;
4806 }
4807
4808 const Variable* var = varDecl.var();
4809 SpvId id = this->nextId(&var->type());
4810 fVariableMap.set(var, id);
4811 SpvId type = this->getPointerType(var->type(), StorageClass::kFunction);
4812 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
4813 this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
4814 if (varDecl.value()) {
4815 SpvId value = this->writeExpression(*varDecl.value(), out);
4816 this->writeOpStore(StorageClass::kFunction, id, value, out);
4817 }
4818 }
4819
writeStatement(const Statement & s,OutputStream & out)4820 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
4821 switch (s.kind()) {
4822 case Statement::Kind::kNop:
4823 break;
4824 case Statement::Kind::kBlock:
4825 this->writeBlock(s.as<Block>(), out);
4826 break;
4827 case Statement::Kind::kExpression:
4828 this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
4829 break;
4830 case Statement::Kind::kReturn:
4831 this->writeReturnStatement(s.as<ReturnStatement>(), out);
4832 break;
4833 case Statement::Kind::kVarDeclaration:
4834 this->writeVarDeclaration(s.as<VarDeclaration>(), out);
4835 break;
4836 case Statement::Kind::kIf:
4837 this->writeIfStatement(s.as<IfStatement>(), out);
4838 break;
4839 case Statement::Kind::kFor:
4840 this->writeForStatement(s.as<ForStatement>(), out);
4841 break;
4842 case Statement::Kind::kDo:
4843 this->writeDoStatement(s.as<DoStatement>(), out);
4844 break;
4845 case Statement::Kind::kSwitch:
4846 this->writeSwitchStatement(s.as<SwitchStatement>(), out);
4847 break;
4848 case Statement::Kind::kBreak:
4849 this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
4850 break;
4851 case Statement::Kind::kContinue:
4852 this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
4853 break;
4854 case Statement::Kind::kDiscard:
4855 this->writeInstruction(SpvOpKill, out);
4856 break;
4857 default:
4858 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
4859 break;
4860 }
4861 }
4862
writeBlock(const Block & b,OutputStream & out)4863 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
4864 for (const std::unique_ptr<Statement>& stmt : b.children()) {
4865 this->writeStatement(*stmt, out);
4866 }
4867 }
4868
getConditionalOpCounts()4869 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
4870 return {fReachableOps.size(), fStoreOps.size()};
4871 }
4872
pruneConditionalOps(ConditionalOpCounts ops)4873 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
4874 // Remove ops which are no longer reachable.
4875 while (fReachableOps.size() > ops.numReachableOps) {
4876 SpvId prunableSpvId = fReachableOps.back();
4877 const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
4878
4879 if (prunableOp) {
4880 fOpCache.remove(*prunableOp);
4881 fSpvIdCache.remove(prunableSpvId);
4882 } else {
4883 SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
4884 }
4885
4886 fReachableOps.pop_back();
4887 }
4888
4889 // Remove any cached stores that occurred during the conditional block.
4890 while (fStoreOps.size() > ops.numStoreOps) {
4891 if (fStoreCache.find(fStoreOps.back())) {
4892 fStoreCache.remove(fStoreOps.back());
4893 }
4894 fStoreOps.pop_back();
4895 }
4896 }
4897
writeIfStatement(const IfStatement & stmt,OutputStream & out)4898 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
4899 SpvId test = this->writeExpression(*stmt.test(), out);
4900 SpvId ifTrue = this->nextId(nullptr);
4901 SpvId ifFalse = this->nextId(nullptr);
4902
4903 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4904
4905 if (stmt.ifFalse()) {
4906 SpvId end = this->nextId(nullptr);
4907 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4908 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4909 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4910 this->writeStatement(*stmt.ifTrue(), out);
4911 if (fCurrentBlock) {
4912 this->writeInstruction(SpvOpBranch, end, out);
4913 }
4914 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4915 this->writeStatement(*stmt.ifFalse(), out);
4916 if (fCurrentBlock) {
4917 this->writeInstruction(SpvOpBranch, end, out);
4918 }
4919 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4920 } else {
4921 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
4922 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
4923 this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
4924 this->writeStatement(*stmt.ifTrue(), out);
4925 if (fCurrentBlock) {
4926 this->writeInstruction(SpvOpBranch, ifFalse, out);
4927 }
4928 this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
4929 }
4930 }
4931
// Emits a `for` loop as a SPIR-V structured loop:
//   %header:  OpLoopMerge %end %next; OpBranch %start
//   %start:   evaluate the test (if any) and branch to %body or %end
//   %body:    the loop statement; OpBranch %next (if the block is still open)
//   %next:    the continue target; evaluate the `next` expression and branch back to %header
//   %end:     the merge block
// While the body is written, `break` branches to %end and `continue` branches to %next.
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
    // The initializer runs once, before the loop header.
    if (f.initializer()) {
        this->writeStatement(*f.initializer(), out);
    }

    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): the invalidation appears to be driven by passing `conditionalOps` to
    // writeLabel() below; confirm this comment still matches writeLabel's behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId body = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    fContinueTarget.push_back(next);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    // A missing test means the loop only exits via `break` (or return/discard).
    if (f.test()) {
        SpvId test = this->writeExpression(*f.test(), out);
        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
    } else {
        this->writeInstruction(SpvOpBranch, body, out);
    }
    this->writeLabel(body, kBranchIsOnPreviousLine, out);
    this->writeStatement(*f.statement(), out);
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
    }
    // Continue target: run the `next` expression, then jump back to the header.
    this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
    if (f.next()) {
        this->writeExpression(*f.next(), out);
    }
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
4975
// Emits a `do`-`while` loop as a SPIR-V structured loop:
//   %header:          OpLoopMerge %end %continueTarget; OpBranch %start
//   %start:           the loop body; if the block is still open, OpBranch %next
//   %next:            OpBranch %continueTarget
//   %continueTarget:  evaluate the test; branch back to %header if true, else to %end
//   %end:             the merge block
// While the body is written, `continue` branches to %continueTarget (so the test still runs) and
// `break` branches to %end.
void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): the invalidation appears to be driven by passing `conditionalOps` to
    // writeLabel() below; confirm this comment still matches writeLabel's behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    SpvId continueTarget = this->nextId(nullptr);
    fContinueTarget.push_back(continueTarget);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    this->writeStatement(*d.statement(), out);
    // If the body fell off its end, route through %next to the continue target. (The %next label
    // is only emitted on this path.)
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
        this->writeLabel(next, kBranchIsOnPreviousLine, out);
        this->writeInstruction(SpvOpBranch, continueTarget, out);
    }
    this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
    SpvId test = this->writeExpression(*d.test(), out);
    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
5008
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)5009 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
5010 SpvId value = this->writeExpression(*s.value(), out);
5011
5012 ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
5013
5014 // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
5015 // in the context of linear straight-line execution. If we wanted to be more clever, we could
5016 // only invalidate store cache entries for variables affected by the switch body, but for now we
5017 // simply clear the entire cache whenever branching occurs.
5018 TArray<SpvId> labels;
5019 SpvId end = this->nextId(nullptr);
5020 SpvId defaultLabel = end;
5021 fBreakTarget.push_back(end);
5022 int size = 3;
5023 const StatementArray& cases = s.cases();
5024 for (const std::unique_ptr<Statement>& stmt : cases) {
5025 const SwitchCase& c = stmt->as<SwitchCase>();
5026 SpvId label = this->nextId(nullptr);
5027 labels.push_back(label);
5028 if (!c.isDefault()) {
5029 size += 2;
5030 } else {
5031 defaultLabel = label;
5032 }
5033 }
5034
5035 // We should have exactly one label for each case.
5036 SkASSERT(labels.size() == cases.size());
5037
5038 // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
5039 // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
5040 // does support multiple switch-cases branching to the same label.
5041 SkBitSet caseIsCollapsed(cases.size());
5042 for (int index = cases.size() - 2; index >= 0; index--) {
5043 if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
5044 caseIsCollapsed.set(index);
5045 labels[index] = labels[index + 1];
5046 }
5047 }
5048
5049 labels.push_back(end);
5050
5051 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
5052 this->writeOpCode(SpvOpSwitch, size, out);
5053 this->writeWord(value, out);
5054 this->writeWord(defaultLabel, out);
5055 for (int i = 0; i < cases.size(); ++i) {
5056 const SwitchCase& c = cases[i]->as<SwitchCase>();
5057 if (c.isDefault()) {
5058 continue;
5059 }
5060 this->writeWord(c.value(), out);
5061 this->writeWord(labels[i], out);
5062 }
5063 for (int i = 0; i < cases.size(); ++i) {
5064 if (caseIsCollapsed.test(i)) {
5065 continue;
5066 }
5067 const SwitchCase& c = cases[i]->as<SwitchCase>();
5068 if (i == 0) {
5069 this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
5070 } else {
5071 this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
5072 }
5073 this->writeStatement(*c.statement(), out);
5074 if (fCurrentBlock) {
5075 this->writeInstruction(SpvOpBranch, labels[i + 1], out);
5076 }
5077 }
5078 this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5079 fBreakTarget.pop_back();
5080 }
5081
writeReturnStatement(const ReturnStatement & r,OutputStream & out)5082 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
5083 if (r.expression()) {
5084 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
5085 out);
5086 } else {
5087 this->writeInstruction(SpvOpReturn, out);
5088 }
5089 }
5090
5091 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)5092 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
5093 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
5094 }
5095
// Synthesizes an `_entrypoint()` function whose body is `sk_FragColor = main();`, and returns its
// declaration and definition. If main() takes a single float2 parameter, it is called with
// float2(0, 0). Reports an error and returns an empty adapter if main's return type does not match
// sk_FragColor's type, or if main's single parameter is not a float2.
SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
        const FunctionDeclaration& main) {
    // Our goal is to synthesize a tiny helper function which looks like this:
    //   void _entrypoint() { sk_FragColor = main(); }

    // Fish a symbol table out of main().
    SymbolTable* symbolTable = get_top_level_symbol_table(main);

    // Get `sk_FragColor` as a writable reference.
    // NOTE(review): the lookup result is only asserted, not checked, in release builds.
    const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
    SkASSERT(skFragColorSymbol);
    const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
    auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
                                                              VariableReference::RefKind::kWrite);

    // TODO get secondary frag color as one as well?

    // Synthesize a call to the `main()` function.
    if (!main.returnType().matches(skFragColorRef->type())) {
        fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
                                                main.returnType().description() + "' from main()");
        return {};
    }
    ExpressionArray args;
    // NOTE(review): a main() with more than one parameter is not handled here; presumably the
    // front end rejects such signatures earlier.
    if (main.parameters().size() == 1) {
        if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
            fContext.fErrors->error(main.fPosition,
                                    "SPIR-V does not support parameter of type '" +
                                    main.parameters()[0]->type().description() + "' to main()");
            return {};
        }
        // Pass float2(0, 0) for main's coordinate parameter.
        double kZero[2] = {0.0, 0.0};
        args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
                                                              *fContext.fTypes.fFloat2, kZero));
    }
    auto callMainFn = FunctionCall::Make(fContext, Position(), &main.returnType(),
                                         main, std::move(args));

    // Synthesize `skFragColor = main()` as a BinaryExpression.
    auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
            Position(),
            std::move(skFragColorRef),
            Operator::Kind::EQ,
            std::move(callMainFn),
            &main.returnType()));

    // Function bodies are always wrapped in a Block.
    StatementArray entrypointStmts;
    entrypointStmts.push_back(std::move(assignmentStmt));
    auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
                                       Block::Kind::kBracedScope, /*symbols=*/nullptr);
    // Declare an entrypoint function.
    EntrypointAdapter adapter;
    adapter.entrypointDecl =
            std::make_unique<FunctionDeclaration>(fContext,
                                                  Position(),
                                                  ModifierFlag::kNone,
                                                  "_entrypoint",
                                                  /*parameters=*/TArray<Variable*>{},
                                                  /*returnType=*/fContext.fTypes.fVoid.get(),
                                                  kNotIntrinsic);
    // Define it.
    // NOTE(review): the result of Convert is not null-checked before use below.
    adapter.entrypointDef = FunctionDefinition::Convert(fContext,
                                                        Position(),
                                                        *adapter.entrypointDecl,
                                                        std::move(entrypointBlock));

    // Link the declaration back to its definition.
    adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
    return adapter;
}
5166
writeUniformBuffer(SymbolTable * topLevelSymbolTable)5167 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5168 SkASSERT(!fTopLevelUniforms.empty());
5169 static constexpr char kUniformBufferName[] = "_UniformBuffer";
5170
5171 // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5172 // a lookup table of variables to UniformBuffer field indices.
5173 TArray<Field> fields;
5174 fields.reserve_exact(fTopLevelUniforms.size());
5175 for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5176 const Variable* var = topLevelUniform->var();
5177 fTopLevelUniformMap.set(var, (int)fields.size());
5178 ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5179 fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5180 }
5181 fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5182 Position(),
5183 kUniformBufferName,
5184 std::move(fields),
5185 /*interfaceBlock=*/true);
5186
5187 // Create a global variable to contain this struct.
5188 Layout layout;
5189 layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5190 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
5191
5192 fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5193 /*modifiersPosition=*/Position(),
5194 layout,
5195 ModifierFlag::kUniform,
5196 fUniformBuffer.fStruct.get(),
5197 kUniformBufferName,
5198 /*mangledName=*/"",
5199 /*builtin=*/false,
5200 Variable::Storage::kGlobal);
5201
5202 // Create an interface block object for this global variable.
5203 fUniformBuffer.fInterfaceBlock =
5204 std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5205
5206 // Generate an interface block and hold onto its ID.
5207 fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5208 }
5209
// Ensures an RTFlip uniform has been emitted. Does nothing if fWroteRTFlip is already set (for
// example, because writeInterfaceBlock appended RTFlip to an existing block). Otherwise it
// synthesizes an `sksl_synthetic_uniforms` interface block holding a single float2 RTFlip field at
// the configured offset. The block uses push constants if fUsePushConstants is set; otherwise it
// uses the configured RTFlip binding and set. Missing or negative settings are reported as errors.
void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
    SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);

    if (fWroteRTFlip) {
        return;
    }
    // Flip variable hasn't been written yet. This means we don't have an existing
    // interface block, so we're free to just synthesize one.
    fWroteRTFlip = true;
    TArray<Field> fields;
    if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
        fContext.fErrors->error(pos, "RTFlipOffset is negative");
    }
    fields.emplace_back(pos,
                        Layout(LayoutFlag::kNone,
                               /*location=*/-1,
                               fProgram.fConfig->fSettings.fRTFlipOffset,
                               /*binding=*/-1,
                               /*index=*/-1,
                               /*set=*/-1,
                               /*builtin=*/-1,
                               /*inputAttachmentIndex=*/-1),
                        ModifierFlag::kNone,
                        SKSL_RTFLIP_NAME,
                        fContext.fTypes.fFloat2.get());
    std::string_view name = "sksl_synthetic_uniforms";
    const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
            fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
    // Push constants take no binding/set; otherwise both must come from the settings.
    bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
    int binding = -1, set = -1;
    if (!usePushConstants) {
        binding = fProgram.fConfig->fSettings.fRTFlipBinding;
        if (binding == -1) {
            fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
        }
        set = fProgram.fConfig->fSettings.fRTFlipSet;
        if (set == -1) {
            fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
        }
    }
    Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
                  /*location=*/-1,
                  /*offset=*/-1,
                  binding,
                  /*index=*/-1,
                  set,
                  /*builtin=*/-1,
                  /*inputAttachmentIndex=*/-1);
    Variable* intfVar =
            fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
                                                             /*modifiersPosition=*/Position(),
                                                             layout,
                                                             ModifierFlag::kUniform,
                                                             intfStruct,
                                                             name,
                                                             /*mangledName=*/"",
                                                             /*builtin=*/false,
                                                             Variable::Storage::kGlobal));
    {
        // Register a FieldSymbol for field 0 (the RTFlip field) in the program's symbol table,
        // with the program's pool attached.
        AutoAttachPoolToThread attach(fProgram.fPool.get());
        fProgram.fSymbols->add(fContext,
                               std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
    }
    InterfaceBlock intf(Position(), intfVar);
    // appendRTFlip=false: this block already contains the RTFlip field.
    this->writeInterfaceBlock(intf, false);
}
5276
synthesizeTextureAndSampler(const Variable & combinedSampler)5277 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5278 const Variable& combinedSampler) {
5279 SkASSERT(fUseTextureSamplerPairs);
5280 SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5281
5282 if (std::unique_ptr<SynthesizedTextureSamplerPair>* existingData =
5283 fSynthesizedSamplerMap.find(&combinedSampler)) {
5284 return {(*existingData)->fTexture.get(), (*existingData)->fSampler.get()};
5285 }
5286
5287 auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5288
5289 Layout texLayout = combinedSampler.layout();
5290 texLayout.fBinding = texLayout.fTexture;
5291 data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5292
5293 auto texture = Variable::Make(/*pos=*/Position(),
5294 /*modifiersPosition=*/Position(),
5295 texLayout,
5296 combinedSampler.modifierFlags(),
5297 &combinedSampler.type().textureType(),
5298 data->fTextureName,
5299 /*mangledName=*/"",
5300 /*builtin=*/false,
5301 Variable::Storage::kGlobal);
5302
5303 Layout samplerLayout = combinedSampler.layout();
5304 samplerLayout.fBinding = samplerLayout.fSampler;
5305 samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5306 data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5307
5308 auto sampler = Variable::Make(/*pos=*/Position(),
5309 /*modifiersPosition=*/Position(),
5310 samplerLayout,
5311 combinedSampler.modifierFlags(),
5312 fContext.fTypes.fSampler.get(),
5313 data->fSamplerName,
5314 /*mangledName=*/"",
5315 /*builtin=*/false,
5316 Variable::Storage::kGlobal);
5317
5318 const Variable* t = texture.get();
5319 const Variable* s = sampler.get();
5320 data->fTexture = std::move(texture);
5321 data->fSampler = std::move(sampler);
5322 fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5323
5324 return {t, s};
5325 }
5326
writeInstructions(const Program & program,OutputStream & out)5327 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
5328 Analysis::FindFunctionsToSpecialize(program, &fSpecializationInfo, [](const Variable& param) {
5329 return param.type().isSampler() || param.type().isUnsizedArray();
5330 });
5331
5332 fGLSLExtendedInstructions = this->nextId(nullptr);
5333 StringStream body;
5334
5335 // Do an initial pass over the program elements to establish some baseline info.
5336 const FunctionDeclaration* main = nullptr;
5337 int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
5338 Position combinedSamplerPos;
5339 Position separateSamplerPos;
5340 for (const ProgramElement* e : program.elements()) {
5341 switch (e->kind()) {
5342 case ProgramElement::Kind::kFunction: {
5343 // Assign SpvIds to functions.
5344 const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
5345 const FunctionDeclaration& funcDecl = funcDef.declaration();
5346 if (const Analysis::Specializations* specializations =
5347 fSpecializationInfo.fSpecializationMap.find(&funcDecl)) {
5348 for (int i = 0; i < specializations->size(); i++) {
5349 fFunctionMap.set({&funcDecl, i}, this->nextId(nullptr));
5350 }
5351 } else {
5352 fFunctionMap.set({&funcDecl, Analysis::kUnspecialized}, this->nextId(nullptr));
5353 }
5354 if (funcDecl.isMain()) {
5355 main = &funcDecl;
5356 }
5357 break;
5358 }
5359 case ProgramElement::Kind::kGlobalVar: {
5360 // Look for sampler variables and determine whether or not this program uses
5361 // combined samplers or separate samplers. The layout backend will be marked as
5362 // WebGPU for separate samplers, or Vulkan for combined samplers.
5363 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
5364 const Variable& var = *decl.varDeclaration().var();
5365 if (var.type().isSampler()) {
5366 if (var.layout().fFlags & LayoutFlag::kVulkan) {
5367 combinedSamplerPos = decl.position();
5368 }
5369 if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
5370 separateSamplerPos = decl.position();
5371 }
5372 }
5373 break;
5374 }
5375 case ProgramElement::Kind::kModifiers: {
5376 // If this is a compute program, collect the local-size values. Dimensions that are
5377 // not present will be assigned a value of 1.
5378 if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5379 const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
5380 if (modifiers.layout().fLocalSizeX >= 0) {
5381 localSizeX = modifiers.layout().fLocalSizeX;
5382 }
5383 if (modifiers.layout().fLocalSizeY >= 0) {
5384 localSizeY = modifiers.layout().fLocalSizeY;
5385 }
5386 if (modifiers.layout().fLocalSizeZ >= 0) {
5387 localSizeZ = modifiers.layout().fLocalSizeZ;
5388 }
5389 }
5390 break;
5391 }
5392 default:
5393 break;
5394 }
5395 }
5396
5397 // Make sure we have a main() function.
5398 if (!main) {
5399 fContext.fErrors->error(Position(), "program does not contain a main() function");
5400 return;
5401 }
5402 // Make sure our program's sampler usage is consistent.
5403 if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
5404 fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
5405 fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
5406 fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
5407 return;
5408 }
5409 fUseTextureSamplerPairs = separateSamplerPos.valid();
5410
5411 // Emit interface blocks.
5412 std::set<SpvId> interfaceVars;
5413 for (const ProgramElement* e : program.elements()) {
5414 if (e->is<InterfaceBlock>()) {
5415 const InterfaceBlock& intf = e->as<InterfaceBlock>();
5416 SpvId id = this->writeInterfaceBlock(intf);
5417
5418 if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
5419 intf.var()->layout().fBuiltin == -1) {
5420 interfaceVars.insert(id);
5421 }
5422 }
5423 }
5424 // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
5425 // explicitly declared in the output, whether or not the program explicitly references it.
5426 // However, if the program naturally declares it, we don't want to include it a second time;
5427 // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
5428 const VarDeclaration* missingClockwiseDecl = nullptr;
5429 if (fCaps.fMustDeclareFragmentFrontFacing) {
5430 if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
5431 missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
5432 }
5433 }
5434 // Emit global variable declarations.
5435 for (const ProgramElement* e : program.elements()) {
5436 if (e->is<GlobalVarDeclaration>()) {
5437 const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
5438 if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
5439 return;
5440 }
5441 if (missingClockwiseDecl == &decl) {
5442 // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
5443 missingClockwiseDecl = nullptr;
5444 }
5445 }
5446 }
5447 // All the global variables have been declared. If sk_Clockwise was not naturally included in
5448 // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
5449 if (missingClockwiseDecl) {
5450 if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
5451 return;
5452 }
5453 missingClockwiseDecl = nullptr;
5454 }
5455 // Emit top-level uniforms into a dedicated uniform buffer.
5456 if (!fTopLevelUniforms.empty()) {
5457 this->writeUniformBuffer(get_top_level_symbol_table(*main));
5458 }
5459 // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
5460 // main() and stores the result into sk_FragColor.
5461 EntrypointAdapter adapter;
5462 if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
5463 adapter = this->writeEntrypointAdapter(*main);
5464 if (adapter.entrypointDecl) {
5465 fFunctionMap.set({adapter.entrypointDecl.get(), Analysis::kUnspecialized},
5466 this->nextId(nullptr));
5467 this->writeFunction(*adapter.entrypointDef, body);
5468 main = adapter.entrypointDecl.get();
5469 }
5470 }
5471 // Emit all the functions.
5472 for (const ProgramElement* e : program.elements()) {
5473 if (e->is<FunctionDefinition>()) {
5474 this->writeFunction(e->as<FunctionDefinition>(), body);
5475 }
5476 }
5477 // Add global in/out variables to the list of interface variables.
5478 for (const auto& [var, spvId] : fVariableMap) {
5479 if (var->storage() == Variable::Storage::kGlobal &&
5480 (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
5481 interfaceVars.insert(spvId);
5482 }
5483 }
5484 this->writeCapabilities(out);
5485 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
5486 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
5487 this->writeOpCode(SpvOpEntryPoint,
5488 (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
5489 out);
5490 if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
5491 this->writeWord(SpvExecutionModelVertex, out);
5492 } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5493 this->writeWord(SpvExecutionModelFragment, out);
5494 } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5495 this->writeWord(SpvExecutionModelGLCompute, out);
5496 } else {
5497 SK_ABORT("cannot write this kind of program to SPIR-V\n");
5498 }
5499 const Analysis::SpecializedFunctionKey mainKey{main, Analysis::kUnspecialized};
5500 SpvId entryPoint = fFunctionMap[mainKey];
5501 this->writeWord(entryPoint, out);
5502 this->writeString(main->name(), out);
5503 for (int var : interfaceVars) {
5504 this->writeWord(var, out);
5505 }
5506 if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5507 this->writeInstruction(SpvOpExecutionMode,
5508 fFunctionMap[mainKey],
5509 SpvExecutionModeOriginUpperLeft,
5510 out);
5511 } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5512 this->writeInstruction(SpvOpExecutionMode,
5513 fFunctionMap[mainKey],
5514 SpvExecutionModeLocalSize,
5515 localSizeX, localSizeY, localSizeZ,
5516 out);
5517 }
5518 for (const ProgramElement* e : program.elements()) {
5519 if (e->is<Extension>()) {
5520 this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
5521 }
5522 }
5523
5524 write_stringstream(fNameBuffer, out);
5525 write_stringstream(fDecorationBuffer, out);
5526 write_stringstream(fConstantBuffer, out);
5527 write_stringstream(body, out);
5528 }
5529
generateCode()5530 bool SPIRVCodeGenerator::generateCode() {
5531 SkASSERT(!fContext.fErrors->errorCount());
5532 this->writeWord(SpvMagicNumber, *fOut);
5533 this->writeWord(SpvVersion, *fOut);
5534 this->writeWord(SKSL_MAGIC, *fOut);
5535 StringStream buffer;
5536 this->writeInstructions(fProgram, buffer);
5537 this->writeWord(fIdCount, *fOut);
5538 this->writeWord(0, *fOut); // reserved, always zero
5539 write_stringstream(buffer, *fOut);
5540 return fContext.fErrors->errorCount() == 0;
5541 }
5542
ToSPIRV(Program & program,const ShaderCaps * caps,OutputStream & out,ValidateSPIRVProc validateSPIRV)5543 bool ToSPIRV(Program& program,
5544 const ShaderCaps* caps,
5545 OutputStream& out,
5546 ValidateSPIRVProc validateSPIRV) {
5547 TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5548 SkASSERT(caps != nullptr);
5549
5550 program.fContext->fErrors->setSource(*program.fSource);
5551 bool result;
5552 if (validateSPIRV) {
5553 StringStream buffer;
5554 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5555 result = cg.generateCode();
5556
5557 if (result && program.fConfig->fSettings.fValidateSPIRV) {
5558 std::string_view binary = buffer.str();
5559 result = validateSPIRV(*program.fContext->fErrors, binary);
5560 out.write(binary.data(), binary.size());
5561 }
5562 } else {
5563 SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5564 result = cg.generateCode();
5565 }
5566 program.fContext->fErrors->setSource(std::string_view());
5567
5568 return result;
5569 }
5570
ToSPIRV(Program & program,const ShaderCaps * caps,std::string * out,ValidateSPIRVProc validateSPIRV)5571 bool ToSPIRV(Program& program,
5572 const ShaderCaps* caps,
5573 std::string* out,
5574 ValidateSPIRVProc validateSPIRV) {
5575 StringStream buffer;
5576 if (!ToSPIRV(program, caps, buffer, validateSPIRV)) {
5577 return false;
5578 }
5579 *out = buffer.str();
5580 return true;
5581 }
5582
5583 } // namespace SkSL
5584