1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
8
9 #include "include/core/SkSpan.h"
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "include/private/base/SkTo.h"
13 #include "src/base/SkEnumBitMask.h"
14 #include "src/base/SkScopeExit.h"
15 #include "src/core/SkTHash.h"
16 #include "src/core/SkTraceEvent.h"
17 #include "src/sksl/SkSLAnalysis.h"
18 #include "src/sksl/SkSLBuiltinTypes.h"
19 #include "src/sksl/SkSLCompiler.h"
20 #include "src/sksl/SkSLContext.h"
21 #include "src/sksl/SkSLDefines.h"
22 #include "src/sksl/SkSLErrorReporter.h"
23 #include "src/sksl/SkSLIntrinsicList.h"
24 #include "src/sksl/SkSLMemoryLayout.h"
25 #include "src/sksl/SkSLOperator.h"
26 #include "src/sksl/SkSLOutputStream.h"
27 #include "src/sksl/SkSLPosition.h"
28 #include "src/sksl/SkSLProgramSettings.h"
29 #include "src/sksl/SkSLString.h"
30 #include "src/sksl/SkSLStringStream.h"
31 #include "src/sksl/SkSLUtil.h"
32 #include "src/sksl/analysis/SkSLProgramVisitor.h"
33 #include "src/sksl/codegen/SkSLCodeGenTypes.h"
34 #include "src/sksl/codegen/SkSLCodeGenerator.h"
35 #include "src/sksl/ir/SkSLBinaryExpression.h"
36 #include "src/sksl/ir/SkSLBlock.h"
37 #include "src/sksl/ir/SkSLConstructor.h"
38 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
39 #include "src/sksl/ir/SkSLConstructorCompound.h"
40 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
41 #include "src/sksl/ir/SkSLDoStatement.h"
42 #include "src/sksl/ir/SkSLExpression.h"
43 #include "src/sksl/ir/SkSLExpressionStatement.h"
44 #include "src/sksl/ir/SkSLExtension.h"
45 #include "src/sksl/ir/SkSLFieldAccess.h"
46 #include "src/sksl/ir/SkSLForStatement.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLFunctionPrototype.h"
51 #include "src/sksl/ir/SkSLIRHelpers.h"
52 #include "src/sksl/ir/SkSLIRNode.h"
53 #include "src/sksl/ir/SkSLIfStatement.h"
54 #include "src/sksl/ir/SkSLIndexExpression.h"
55 #include "src/sksl/ir/SkSLInterfaceBlock.h"
56 #include "src/sksl/ir/SkSLLayout.h"
57 #include "src/sksl/ir/SkSLLiteral.h"
58 #include "src/sksl/ir/SkSLModifierFlags.h"
59 #include "src/sksl/ir/SkSLNop.h"
60 #include "src/sksl/ir/SkSLPostfixExpression.h"
61 #include "src/sksl/ir/SkSLPrefixExpression.h"
62 #include "src/sksl/ir/SkSLProgram.h"
63 #include "src/sksl/ir/SkSLProgramElement.h"
64 #include "src/sksl/ir/SkSLReturnStatement.h"
65 #include "src/sksl/ir/SkSLSetting.h"
66 #include "src/sksl/ir/SkSLStatement.h"
67 #include "src/sksl/ir/SkSLStructDefinition.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLType.h"
73 #include "src/sksl/ir/SkSLVarDeclarations.h"
74 #include "src/sksl/ir/SkSLVariable.h"
75 #include "src/sksl/ir/SkSLVariableReference.h"
76 #include "src/sksl/spirv.h"
77
78 #include <algorithm>
79 #include <cstddef>
80 #include <cstdint>
81 #include <functional>
82 #include <initializer_list>
83 #include <limits>
84 #include <memory>
85 #include <string>
86 #include <string_view>
87 #include <utility>
88 #include <vector>
89
90 using namespace skia_private;
91
92 namespace SkSL {
93
// Translates an SkSL Program into Metal Shading Language source text.
//
// Most output is streamed directly into the CodeGenerator's output stream through write() /
// writeLine(). Helper templates that are only needed on demand (matrix inverse, matrixCompMult,
// outerProduct, ...) are accumulated separately in fExtraFunctions, each at most once.
class MetalCodeGenerator : public CodeGenerator {
public:
    // `pp` controls whether write() emits indentation at the start of each new line.
    MetalCodeGenerator(const Context* context,
                       const ShaderCaps* caps,
                       const Program* program,
                       OutputStream* out,
                       PrettyPrint pp)
            : CodeGenerator(context, caps, program, out)
            // NOTE(review): presumably identifiers that clash with MSL names and must be renamed
            // on output; the code consuming this set is outside this view.
            , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
            , fLineEnding("\n")
            , fPrettyPrint(pp) {}

    bool generateCode() override;

protected:
    using Precedence = OperatorPrecedence;

    // Bit flags describing implicit values a function needs access to (stage inputs/outputs,
    // uniforms, globals, builtins such as the fragment coordinate, ...). Results are cached per
    // function in fRequirements; presumably they drive writeFunctionRequirementParams() /
    // writeFunctionRequirementArgs() — TODO confirm against the out-of-view definitions.
    using Requirements = int;
    static constexpr Requirements kNo_Requirements = 0;
    static constexpr Requirements kInputs_Requirement = 1 << 0;
    static constexpr Requirements kOutputs_Requirement = 1 << 1;
    static constexpr Requirements kUniforms_Requirement = 1 << 2;
    static constexpr Requirements kGlobals_Requirement = 1 << 3;
    static constexpr Requirements kFragCoord_Requirement = 1 << 4;
    static constexpr Requirements kSampleMaskIn_Requirement = 1 << 5;
    static constexpr Requirements kVertexID_Requirement = 1 << 6;
    static constexpr Requirements kInstanceID_Requirement = 1 << 7;
    static constexpr Requirements kThreadgroups_Requirement = 1 << 8;

    class GlobalStructVisitor;
    void visitGlobalStruct(GlobalStructVisitor* visitor);

    class ThreadgroupStructVisitor;
    void visitThreadgroupStruct(ThreadgroupStructVisitor* visitor);

    // Appends text to the output, indenting first if at the start of a line (pretty-print only).
    void write(std::string_view s);

    // Writes `s` followed by fLineEnding and marks the output as being at a line start.
    void writeLine(std::string_view s = std::string_view());

    // Ends the current line unless the output is already at the start of a line.
    void finishLine();

    void writeHeader();

    void writeSampler2DPolyfill();

    void writeUniformStruct();

    void writeInterpolatedAttributes(const Variable& var);

    void writeInputStruct();

    void writeOutputStruct();

    void writeInterfaceBlocks();

    void writeStructDefinitions();

    void writeConstantVariables();

    void writeFields(SkSpan<const Field> fields, Position pos);

    int size(const Type* type, bool isPacked) const;

    int alignment(const Type* type, bool isPacked) const;

    void writeGlobalStruct();

    void writeGlobalInit();

    void writeThreadgroupStruct();

    void writeThreadgroupInit();

    void writePrecisionModifier();

    // Returns the MSL spelling of an SkSL type (e.g. `array<T, N>`, `float3x3`, `texture2d<half>`).
    std::string typeName(const Type& type);

    void writeStructDefinition(const StructDefinition& s);

    void writeType(const Type& type);

    void writeExtension(const Extension& ext);

    void writeInterfaceBlock(const InterfaceBlock& intf);

    void writeFunctionRequirementParams(const FunctionDeclaration& f,
                                        const char*& separator);

    void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);

    bool writeFunctionDeclaration(const FunctionDeclaration& f);

    void writeFunction(const FunctionDefinition& f);

    void writeFunctionPrototype(const FunctionPrototype& f);

    void writeLayout(const Layout& layout);

    void writeModifiers(ModifierFlags flags);

    void writeVarInitializer(const Variable& var, const Expression& value);

    void writeName(std::string_view name);

    void writeVarDeclaration(const VarDeclaration& decl);

    void writeFragCoord();

    void writeVariableReference(const VariableReference& ref);

    void writeExpression(const Expression& expr, Precedence parentPrecedence);

    void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);

    // Emits (once) a templated inverse() polyfill for the argument's matrix size; returns its name.
    std::string getInversePolyfill(const ExpressionArray& arguments);

    std::string getBitcastIntrinsic(const Type& outType);

    // Returns a fresh `_skTempN` name and appends its declaration to fFunctionHeader.
    std::string getTempVariable(const Type& varType);

    void writeFunctionCall(const FunctionCall& c);

    bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
    std::string getMatrixConstructHelper(const AnyConstructor& c);
    void assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows);
    void assembleMatrixFromExpressions(const AnyConstructor& ctor, int columns, int rows);

    void writeMatrixCompMult();

    void writeOuterProduct();

    void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);

    void writeMatrixDivisionHelpers(const Type& type);

    void writeMatrixEqualityHelpers(const Type& left, const Type& right);

    std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);

    void writeArrayEqualityHelpers(const Type& type);

    void writeStructEqualityHelpers(const Type& type);

    void writeEqualityHelpers(const Type& leftType, const Type& rightType);

    void writeArgumentList(const ExpressionArray& arguments);

    void writeSimpleIntrinsic(const FunctionCall& c);

    bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);

    void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);

    void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
                                      Precedence parentPrecedence);

    void writeAnyConstructor(const AnyConstructor& c,
                             const char* leftBracket,
                             const char* rightBracket,
                             Precedence parentPrecedence);

    void writeCastConstructor(const AnyConstructor& c,
                              const char* leftBracket,
                              const char* rightBracket,
                              Precedence parentPrecedence);

    void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);

    void writeFieldAccess(const FieldAccess& f);

    void writeSwizzle(const Swizzle& swizzle);

    // Returns `floatCxR(1.0, 1.0, 1.0, 1.0, ...)`.
    std::string splatMatrixOf1(const Type& type);

    // Splats a scalar expression across a matrix of arbitrary size.
    void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);

    void writeBinaryExpressionElement(const Expression& expr,
                                      Operator op,
                                      const Expression& other,
                                      Precedence precedence);

    void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);

    void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);

    void writeIndexExpression(const IndexExpression& expr);

    void writeIndexInnerExpression(const Expression& expr);

    void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);

    void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);

    void writeLiteral(const Literal& f);

    void writeStatement(const Statement& s);

    void writeStatements(const StatementArray& statements);

    void writeBlock(const Block& b);

    void writeIfStatement(const IfStatement& stmt);

    void writeForStatement(const ForStatement& f);

    void writeDoStatement(const DoStatement& d);

    void writeExpressionStatement(const ExpressionStatement& s);

    void writeSwitchStatement(const SwitchStatement& s);

    void writeReturnStatementFromMain();

    void writeReturnStatement(const ReturnStatement& r);

    void writeProgramElement(const ProgramElement& e);

    Requirements requirements(const FunctionDeclaration& f);

    Requirements requirements(const Statement* s);

    // For compute shader main functions, writes and initializes the _in and _out structs (the
    // instances, not the types themselves)
    void writeComputeMainInputs();

    int getUniformBinding(const Layout& layout);

    int getUniformSet(const Layout& layout);

    // Runs `fn` with fresh index-substitution state; any captured prefix expressions are emitted
    // together with fn's output as a parenthesized sequence-expression.
    void writeWithIndexSubstitution(const std::function<void()>& fn);

    skia_private::THashSet<std::string_view> fReservedWords;
    skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
    int fAnonInterfaceCount = 0;
    int fPaddingCount = 0;
    // Emitted by writeLine() after every line.
    const char* fLineEnding;
    // Declarations of scratch temporaries created by getTempVariable().
    std::string fFunctionHeader;
    // On-demand helper/polyfill source (inverse, matrixCompMult, outerProduct, ...).
    StringStream fExtraFunctions;
    StringStream fExtraFunctionPrototypes;
    // Counter used to give each `_skTempN` temporary a unique name.
    int fVarCount = 0;
    // Current nesting depth; write() emits one indent unit per level when pretty-printing.
    int fIndentation = 0;
    // True right after writeLine(); the next non-empty write() indents first.
    bool fAtLineStart = false;
    // true if we have run into usages of dFdx / dFdy
    bool fFoundDerivatives = false;
    skia_private::THashMap<const FunctionDeclaration*, Requirements> fRequirements;
    skia_private::THashSet<std::string> fHelpers;
    int fUniformBuffer = -1;
    std::string fRTFlipName;
    const FunctionDeclaration* fCurrentFunction = nullptr;
    int fSwizzleHelperCount = 0;
    static constexpr char kTextureSuffix[] = "_Tex";
    static constexpr char kSamplerSuffix[] = "_Smplr";

    // If we might use an index expression more than once, we need to capture the result in a
    // temporary variable to avoid double-evaluation. This should generally only occur when emitting
    // a function call, since we need to polyfill GLSL-style out-parameter support. (skia:14130)
    // The map holds <index-expression, temp-variable name>.
    using IndexSubstitutionMap = skia_private::THashMap<const Expression*, std::string>;

    // When fIndexSubstitutionData is null (usually), index-substitution does not need to be
    // performed. While non-null, output is redirected into fMainStream, and prefix expressions
    // that capture indices into temporaries are collected in fPrefixStream.
    struct IndexSubstitutionData {
        IndexSubstitutionMap fMap;
        StringStream fMainStream;
        StringStream fPrefixStream;
        bool fCreateSubstitutes = true;
    };
    std::unique_ptr<IndexSubstitutionData> fIndexSubstitutionData;
    PrettyPrint fPrettyPrint;

    // Workaround/polyfill flags: each is set once the matching helper has been appended to
    // fExtraFunctions, so every helper is emitted at most once per program.
    bool fWrittenInverse2 = false, fWrittenInverse3 = false, fWrittenInverse4 = false;
    bool fWrittenMatrixCompMult = false;
    bool fWrittenOuterProduct = false;
};
374
operator_name(Operator op)375 static const char* operator_name(Operator op) {
376 switch (op.kind()) {
377 case Operator::Kind::LOGICALXOR: return " != ";
378 default: return op.operatorName();
379 }
380 }
381
382 class MetalCodeGenerator::GlobalStructVisitor {
383 public:
384 virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)385 virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,std::string_view name)386 virtual void visitTexture(const Type& type, std::string_view name) {}
visitSampler(const Type & type,std::string_view name)387 virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)388 virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)389 virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
390 };
391
392 class MetalCodeGenerator::ThreadgroupStructVisitor {
393 public:
394 virtual ~ThreadgroupStructVisitor() = default;
395 virtual void visitNonconstantVariable(const Variable& var) = 0;
396 };
397
write(std::string_view s)398 void MetalCodeGenerator::write(std::string_view s) {
399 if (s.empty()) {
400 return;
401 }
402 if (fAtLineStart && fPrettyPrint == PrettyPrint::kYes) {
403 for (int i = 0; i < fIndentation; i++) {
404 fOut->writeText(" ");
405 }
406 }
407 fOut->writeText(std::string(s).c_str());
408 fAtLineStart = false;
409 }
410
writeLine(std::string_view s)411 void MetalCodeGenerator::writeLine(std::string_view s) {
412 this->write(s);
413 fOut->writeText(fLineEnding);
414 fAtLineStart = true;
415 }
416
finishLine()417 void MetalCodeGenerator::finishLine() {
418 if (!fAtLineStart) {
419 this->writeLine();
420 }
421 }
422
writeExtension(const Extension & ext)423 void MetalCodeGenerator::writeExtension(const Extension& ext) {
424 this->writeLine("#extension " + std::string(ext.name()) + " : enable");
425 }
426
typeName(const Type & raw)427 std::string MetalCodeGenerator::typeName(const Type& raw) {
428 // we need to know the modifiers for textures
429 const Type& type = raw.resolve().scalarTypeForLiteral();
430 switch (type.typeKind()) {
431 case Type::TypeKind::kArray: {
432 std::string typeName = this->typeName(type.componentType());
433 if (type.isUnsizedArray()) {
434 return String::printf("const device %s*", typeName.c_str());
435 } else {
436 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
437 return String::printf("array<%s, %d>", typeName.c_str(), type.columns());
438 }
439 }
440 case Type::TypeKind::kVector:
441 return this->typeName(type.componentType()) + std::to_string(type.columns());
442
443 case Type::TypeKind::kMatrix:
444 return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
445 std::to_string(type.rows());
446
447 case Type::TypeKind::kSampler:
448 if (type.dimensions() != SpvDim2D) {
449 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
450 }
451 return "sampler2D";
452
453 case Type::TypeKind::kTexture:
454 switch (type.textureAccess()) {
455 case Type::TextureAccess::kSample: return "texture2d<half>";
456 case Type::TextureAccess::kRead: return "texture2d<half, access::read>";
457 case Type::TextureAccess::kWrite: return "texture2d<half, access::write>";
458 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
459 default: break;
460 }
461 SkUNREACHABLE;
462
463 case Type::TypeKind::kAtomic:
464 // SkSL currently only supports the atomicUint type.
465 SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
466 return "atomic_uint";
467
468 default:
469 return std::string(type.name());
470 }
471 }
472
writeStructDefinition(const StructDefinition & s)473 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
474 const Type& type = s.type();
475 this->writeLine("struct " + type.displayName() + " {");
476 fIndentation++;
477 this->writeFields(type.fields(), type.fPosition);
478 fIndentation--;
479 this->writeLine("};");
480 }
481
writeType(const Type & type)482 void MetalCodeGenerator::writeType(const Type& type) {
483 this->write(this->typeName(type));
484 }
485
writeExpression(const Expression & expr,Precedence parentPrecedence)486 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
487 switch (expr.kind()) {
488 case Expression::Kind::kBinary:
489 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
490 break;
491 case Expression::Kind::kConstructorArray:
492 case Expression::Kind::kConstructorStruct:
493 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
494 break;
495 case Expression::Kind::kConstructorArrayCast:
496 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
497 break;
498 case Expression::Kind::kConstructorCompound:
499 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
500 break;
501 case Expression::Kind::kConstructorDiagonalMatrix:
502 case Expression::Kind::kConstructorSplat:
503 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
504 break;
505 case Expression::Kind::kConstructorMatrixResize:
506 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
507 parentPrecedence);
508 break;
509 case Expression::Kind::kConstructorScalarCast:
510 case Expression::Kind::kConstructorCompoundCast:
511 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
512 break;
513 case Expression::Kind::kEmpty:
514 this->write("false");
515 break;
516 case Expression::Kind::kFieldAccess:
517 this->writeFieldAccess(expr.as<FieldAccess>());
518 break;
519 case Expression::Kind::kLiteral:
520 this->writeLiteral(expr.as<Literal>());
521 break;
522 case Expression::Kind::kFunctionCall:
523 this->writeFunctionCall(expr.as<FunctionCall>());
524 break;
525 case Expression::Kind::kPrefix:
526 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
527 break;
528 case Expression::Kind::kPostfix:
529 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
530 break;
531 case Expression::Kind::kSetting:
532 this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), parentPrecedence);
533 break;
534 case Expression::Kind::kSwizzle:
535 this->writeSwizzle(expr.as<Swizzle>());
536 break;
537 case Expression::Kind::kVariableReference:
538 this->writeVariableReference(expr.as<VariableReference>());
539 break;
540 case Expression::Kind::kTernary:
541 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
542 break;
543 case Expression::Kind::kIndex:
544 this->writeIndexExpression(expr.as<IndexExpression>());
545 break;
546 default:
547 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
548 break;
549 }
550 }
551
552 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,ModifierFlags flags)553 static bool pass_by_reference(const Type& type, ModifierFlags flags) {
554 return (flags & ModifierFlag::kOut) && !type.isUnsizedArray();
555 }
556
557 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,ModifierFlags modifiers)558 static bool needs_address_space(const Type& type, ModifierFlags modifiers) {
559 return type.isUnsizedArray() || pass_by_reference(type, modifiers);
560 }
561
562 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)563 static bool is_buffer(const InterfaceBlock& block) {
564 return block.var()->modifierFlags().isBuffer();
565 }
566
567 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)568 static bool is_readonly(const InterfaceBlock& block) {
569 return block.var()->modifierFlags().isReadOnly();
570 }
571
getBitcastIntrinsic(const Type & outType)572 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
573 return "as_type<" + outType.displayName() + ">";
574 }
575
writeWithIndexSubstitution(const std::function<void ()> & fn)576 void MetalCodeGenerator::writeWithIndexSubstitution(const std::function<void()>& fn) {
577 auto oldIndexSubstitutionData = std::make_unique<IndexSubstitutionData>();
578 fIndexSubstitutionData.swap(oldIndexSubstitutionData);
579
580 // Invoke our helper function, with output going into our temporary stream.
581 {
582 AutoOutputStream outputToMainStream(this, &fIndexSubstitutionData->fMainStream);
583 fn();
584 }
585
586 if (fIndexSubstitutionData->fPrefixStream.bytesWritten() == 0) {
587 // Emit the main stream into the program as-is.
588 write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
589 } else {
590 // Emit the prefix stream and main stream into the program as a sequence-expression.
591 // (Each prefix-expression must end with a comma.)
592 this->write("(");
593 write_stringstream(fIndexSubstitutionData->fPrefixStream, *fOut);
594 write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
595 this->write(")");
596 }
597
598 fIndexSubstitutionData.swap(oldIndexSubstitutionData);
599 }
600
// Emits a call expression for `c`.
//
// Intrinsics are first offered to writeIntrinsicCall(). If it reports that it handled the call,
// nothing else is written. Otherwise the callee is invoked by its mangled name, with any implicit
// arguments from writeFunctionRequirementArgs() placed ahead of the call's own arguments.
//
// When at least one parameter is `out` or `inout`, the call is wrapped in a comma-separated
// sequence expression, built inside writeWithIndexSubstitution(), that emulates GLSL's
// copy-in/copy-out semantics. Scratch temporaries come from getTempVariable(). The out-param
// temporaries are allocated first, in argument order, then the result temporary; this order
// fixes the generated `_skTempN` numbering.
void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
    const FunctionDeclaration& function = c.function();

    // Many intrinsics need to be rewritten in Metal.
    if (function.isIntrinsic()) {
        if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
            return;
        }
    }

    // Look for out parameters. SkSL guarantees GLSL's out-param semantics, and we need to emulate
    // it if an out-param is encountered. (Specifically, out-parameters in GLSL are only written
    // back to the original variable at the end of the function call; also, swizzles are supported,
    // whereas Metal doesn't allow a swizzle to be passed to a `floatN&`.)
    const ExpressionArray& arguments = c.arguments();
    SkSpan<Variable* const> parameters = function.parameters();
    SkASSERT(SkToSizeT(arguments.size()) == parameters.size());

    // One entry per argument. For out/inout parameters it holds the scratch temporary's name;
    // for by-value parameters it stays empty.
    bool foundOutParam = false;
    STArray<16, std::string> scratchVarName;
    scratchVarName.push_back_n(arguments.size(), std::string());

    for (int index = 0; index < arguments.size(); ++index) {
        // If this is an out parameter...
        if (parameters[index]->modifierFlags() & ModifierFlag::kOut) {
            // Assignability was verified at IRGeneration time, so this should always succeed.
            [[maybe_unused]] Analysis::AssignmentInfo info;
            SkASSERT(Analysis::IsAssignable(*arguments[index], &info));

            scratchVarName[index] = this->getTempVariable(arguments[index]->type());
            foundOutParam = true;
        }
    }

    if (foundOutParam) {
        // Out parameters need to be written back to at the end of the function. To do this, we
        // generate a comma-separated sequence expression that copies the out-param expressions into
        // our temporary variables, calls the original function--storing its result into a scratch
        // variable--and then writes the temp variables back into the original out params using the
        // original out-param expressions. This would look something like:
        //
        // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp), _skResult)
        //       ^                    ^                                  ^                   ^
        //  return value    passes copy of argument        copies back into argument   return value
        //
        // While these expressions are complex, they allow us to maintain the proper sequencing that
        // is necessary for out-parameters, as well as allowing us to support things like swizzles
        // and array indices which Metal references cannot natively handle.

        // We will be emitting inout expressions twice, so it's important to enable index
        // substitution in case we encounter any side-effecting indexes.
        this->writeWithIndexSubstitution([&] {
            this->write("((");

            // ((_skResult =
            // The result temporary exists only for non-void callees; when it is left empty,
            // the trailing `, _skResult` below is omitted as well.
            std::string scratchResultName;
            if (!function.returnType().isVoid()) {
                scratchResultName = this->getTempVariable(c.type());
                this->write(scratchResultName);
                this->write(" = ");
            }

            // ((_skResult = func(
            this->write(function.mangledName());
            this->write("(");

            // ((_skResult = func((_skTemp = myOutParam.x), 123
            const char* separator = "";
            this->writeFunctionRequirementArgs(function, separator);

            for (int i = 0; i < arguments.size(); ++i) {
                this->write(separator);
                separator = ", ";
                if (parameters[i]->modifierFlags() & ModifierFlag::kOut) {
                    SkASSERT(!scratchVarName[i].empty());
                    if (parameters[i]->modifierFlags() & ModifierFlag::kIn) {
                        // `inout` parameters initialize the scratch variable with the passed-in
                        // argument's value.
                        this->write("(");
                        this->write(scratchVarName[i]);
                        this->write(" = ");
                        this->writeExpression(*arguments[i], Precedence::kAssignment);
                        this->write(")");
                    } else {
                        // `out` parameters pass a reference to the uninitialized scratch variable.
                        this->write(scratchVarName[i]);
                    }
                } else {
                    // Regular parameters are passed as-is.
                    this->writeExpression(*arguments[i], Precedence::kSequence);
                }
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123))
            this->write("))");

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp)
            // Copy every scratch temporary back into its original argument expression, in
            // argument order.
            for (int i = 0; i < arguments.size(); ++i) {
                if (!scratchVarName[i].empty()) {
                    this->write(", (");
                    this->writeExpression(*arguments[i], Precedence::kAssignment);
                    this->write(" = ");
                    this->write(scratchVarName[i]);
                    this->write(")");
                }
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
            // _skResult
            if (!scratchResultName.empty()) {
                this->write(", ");
                this->write(scratchResultName);
            }

            // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
            // _skResult)
            this->write(")");
        });
    } else {
        // Emit the function call as-is, only prepending the required arguments.
        // No parameter is `out` here, so every scratchVarName entry is still empty.
        this->write(function.mangledName());
        this->write("(");
        const char* separator = "";
        this->writeFunctionRequirementArgs(function, separator);
        for (int i = 0; i < arguments.size(); ++i) {
            SkASSERT(scratchVarName[i].empty());
            this->write(separator);
            separator = ", ";
            this->writeExpression(*arguments[i], Precedence::kSequence);
        }
        this->write(")");
    }
}
734
// MSL source for a templated 2x2 matrix inverse: the adjugate scaled by 1/determinant.
// getInversePolyfill() appends this to fExtraFunctions at most once. There is no
// singularity check; the reciprocal of the determinant is used directly.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
741
// MSL source for a templated 3x3 matrix inverse. It is the adjugate scaled by 1/det, where
// aCR = m[C] component R. The cofactors b01, b11 and b21 are shared: they feed the determinant
// (expanded along a00, a01, a02) and are reused as entries of the result.
// getInversePolyfill() appends this to fExtraFunctions at most once. There is no singularity check.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
b01 = a22*a11 - a12*a21,
b11 = -a22*a10 + a12*a20,
b21 = a21*a10 - a11*a20,
det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
759
// MSL source for a templated 4x4 matrix inverse, where aCR = m[C] component R.
// b00..b05 are the 2x2 determinants formed from columns 0 and 1; b06..b11 are those formed from
// columns 2 and 3. Both sets are shared between the determinant and the adjugate entries, and the
// adjugate is then scaled by 1/det. getInversePolyfill() appends this to fExtraFunctions at most
// once. There is no singularity check.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
b00 = a00*a11 - a01*a10,
b01 = a00*a12 - a02*a10,
b02 = a00*a13 - a03*a10,
b03 = a01*a12 - a02*a11,
b04 = a01*a13 - a03*a11,
b05 = a02*a13 - a03*a12,
b06 = a20*a31 - a21*a30,
b07 = a20*a32 - a22*a30,
b08 = a20*a33 - a23*a30,
b09 = a21*a32 - a22*a31,
b10 = a21*a33 - a23*a31,
b11 = a22*a33 - a23*a32,
det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
a11*b11 - a12*b10 + a13*b09,
a02*b10 - a01*b11 - a03*b09,
a31*b05 - a32*b04 + a33*b03,
a22*b04 - a21*b05 - a23*b03,
a12*b08 - a10*b11 - a13*b07,
a00*b11 - a02*b08 + a03*b07,
a32*b02 - a30*b05 - a33*b01,
a20*b05 - a22*b02 + a23*b01,
a10*b10 - a11*b08 + a13*b06,
a01*b08 - a00*b10 - a03*b06,
a30*b04 - a31*b02 + a33*b00,
a21*b02 - a20*b04 - a23*b00,
a11*b07 - a10*b09 - a12*b06,
a00*b09 - a01*b07 + a02*b06,
a31*b01 - a30*b03 - a32*b00,
a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
800
getInversePolyfill(const ExpressionArray & arguments)801 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
802 // Only use polyfills for a function taking a single-argument square matrix.
803 SkASSERT(arguments.size() == 1);
804 const Type& type = arguments.front()->type();
805 if (type.isMatrix() && type.rows() == type.columns()) {
806 switch (type.rows()) {
807 case 2:
808 if (!fWrittenInverse2) {
809 fWrittenInverse2 = true;
810 fExtraFunctions.writeText(kInverse2x2);
811 }
812 return "mat2_inverse";
813 case 3:
814 if (!fWrittenInverse3) {
815 fWrittenInverse3 = true;
816 fExtraFunctions.writeText(kInverse3x3);
817 }
818 return "mat3_inverse";
819 case 4:
820 if (!fWrittenInverse4) {
821 fWrittenInverse4 = true;
822 fExtraFunctions.writeText(kInverse4x4);
823 }
824 return "mat4_inverse";
825 }
826 }
827 SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
828 return "inverse";
829 }
830
writeMatrixCompMult()831 void MetalCodeGenerator::writeMatrixCompMult() {
832 static constexpr char kMatrixCompMult[] = R"(
833 template <typename T, int C, int R>
834 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
835 for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
836 return a;
837 }
838 )";
839 if (!fWrittenMatrixCompMult) {
840 fWrittenMatrixCompMult = true;
841 fExtraFunctions.writeText(kMatrixCompMult);
842 }
843 }
844
writeOuterProduct()845 void MetalCodeGenerator::writeOuterProduct() {
846 static constexpr char kOuterProduct[] = R"(
847 template <typename T, int C, int R>
848 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
849 matrix<T, C, R> m;
850 for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
851 return m;
852 }
853 )";
854 if (!fWrittenOuterProduct) {
855 fWrittenOuterProduct = true;
856 fExtraFunctions.writeText(kOuterProduct);
857 }
858 }
859
getTempVariable(const Type & type)860 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
861 std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
862 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
863 return tempVar;
864 }
865
writeSimpleIntrinsic(const FunctionCall & c)866 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
867 // Write out an intrinsic function call exactly as-is. No muss no fuss.
868 this->write(c.function().name());
869 this->writeArgumentList(c.arguments());
870 }
871
writeArgumentList(const ExpressionArray & arguments)872 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
873 this->write("(");
874 const char* separator = "";
875 for (const std::unique_ptr<Expression>& arg : arguments) {
876 this->write(separator);
877 separator = ", ";
878 this->writeExpression(*arg, Precedence::kSequence);
879 }
880 this->write(")");
881 }
882
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)883 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
884 const ExpressionArray& arguments = c.arguments();
885 switch (kind) {
886 case k_textureRead_IntrinsicKind: {
887 this->writeExpression(*arguments[0], Precedence::kExpression);
888 this->write(".read(");
889 this->writeExpression(*arguments[1], Precedence::kSequence);
890 this->write(")");
891 return true;
892 }
893 case k_textureWrite_IntrinsicKind: {
894 this->writeExpression(*arguments[0], Precedence::kExpression);
895 this->write(".write(");
896 this->writeExpression(*arguments[2], Precedence::kSequence);
897 this->write(", ");
898 this->writeExpression(*arguments[1], Precedence::kSequence);
899 this->write(")");
900 return true;
901 }
902 case k_textureWidth_IntrinsicKind: {
903 this->writeExpression(*arguments[0], Precedence::kExpression);
904 this->write(".get_width()");
905 return true;
906 }
907 case k_textureHeight_IntrinsicKind: {
908 this->writeExpression(*arguments[0], Precedence::kExpression);
909 this->write(".get_height()");
910 return true;
911 }
912 case k_mod_IntrinsicKind: {
913 // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
914 std::string tmpX = this->getTempVariable(arguments[0]->type());
915 std::string tmpY = this->getTempVariable(arguments[1]->type());
916 this->write("(" + tmpX + " = ");
917 this->writeExpression(*arguments[0], Precedence::kSequence);
918 this->write(", " + tmpY + " = ");
919 this->writeExpression(*arguments[1], Precedence::kSequence);
920 this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
921 return true;
922 }
923 case k_pow_IntrinsicKind: {
924 // The Metal equivalent to GLSL's pow() is powr(). Metal's pow() allows negative base
925 // values, which is presumably more expensive to compute.
926 this->write("powr(");
927 this->writeExpression(*arguments[0], Precedence::kSequence);
928 this->write(", ");
929 this->writeExpression(*arguments[1], Precedence::kSequence);
930 this->write(")");
931 return true;
932 }
933 // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
934 case k_distance_IntrinsicKind: {
935 if (arguments[0]->type().columns() == 1) {
936 this->write("abs(");
937 this->writeExpression(*arguments[0], Precedence::kAdditive);
938 this->write(" - ");
939 this->writeExpression(*arguments[1], Precedence::kAdditive);
940 this->write(")");
941 } else {
942 this->writeSimpleIntrinsic(c);
943 }
944 return true;
945 }
946 case k_dot_IntrinsicKind: {
947 if (arguments[0]->type().columns() == 1) {
948 this->write("(");
949 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
950 this->write(" * ");
951 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
952 this->write(")");
953 } else {
954 this->writeSimpleIntrinsic(c);
955 }
956 return true;
957 }
958 case k_faceforward_IntrinsicKind: {
959 if (arguments[0]->type().columns() == 1) {
960 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
961 this->write("((((");
962 this->writeExpression(*arguments[2], Precedence::kSequence);
963 this->write(") * (");
964 this->writeExpression(*arguments[1], Precedence::kSequence);
965 this->write(") < 0) ? 1 : -1) * (");
966 this->writeExpression(*arguments[0], Precedence::kSequence);
967 this->write("))");
968 } else {
969 this->writeSimpleIntrinsic(c);
970 }
971 return true;
972 }
973 case k_length_IntrinsicKind: {
974 this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
975 this->writeExpression(*arguments[0], Precedence::kSequence);
976 this->write(")");
977 return true;
978 }
979 case k_normalize_IntrinsicKind: {
980 this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
981 this->writeExpression(*arguments[0], Precedence::kSequence);
982 this->write(")");
983 return true;
984 }
985 case k_packUnorm2x16_IntrinsicKind: {
986 this->write("pack_float_to_unorm2x16(");
987 this->writeExpression(*arguments[0], Precedence::kSequence);
988 this->write(")");
989 return true;
990 }
991 case k_unpackUnorm2x16_IntrinsicKind: {
992 this->write("unpack_unorm2x16_to_float(");
993 this->writeExpression(*arguments[0], Precedence::kSequence);
994 this->write(")");
995 return true;
996 }
997 case k_packSnorm2x16_IntrinsicKind: {
998 this->write("pack_float_to_snorm2x16(");
999 this->writeExpression(*arguments[0], Precedence::kSequence);
1000 this->write(")");
1001 return true;
1002 }
1003 case k_unpackSnorm2x16_IntrinsicKind: {
1004 this->write("unpack_snorm2x16_to_float(");
1005 this->writeExpression(*arguments[0], Precedence::kSequence);
1006 this->write(")");
1007 return true;
1008 }
1009 case k_packUnorm4x8_IntrinsicKind: {
1010 this->write("pack_float_to_unorm4x8(");
1011 this->writeExpression(*arguments[0], Precedence::kSequence);
1012 this->write(")");
1013 return true;
1014 }
1015 case k_unpackUnorm4x8_IntrinsicKind: {
1016 this->write("unpack_unorm4x8_to_float(");
1017 this->writeExpression(*arguments[0], Precedence::kSequence);
1018 this->write(")");
1019 return true;
1020 }
1021 case k_packSnorm4x8_IntrinsicKind: {
1022 this->write("pack_float_to_snorm4x8(");
1023 this->writeExpression(*arguments[0], Precedence::kSequence);
1024 this->write(")");
1025 return true;
1026 }
1027 case k_unpackSnorm4x8_IntrinsicKind: {
1028 this->write("unpack_snorm4x8_to_float(");
1029 this->writeExpression(*arguments[0], Precedence::kSequence);
1030 this->write(")");
1031 return true;
1032 }
1033 case k_packHalf2x16_IntrinsicKind: {
1034 this->write("as_type<uint>(half2(");
1035 this->writeExpression(*arguments[0], Precedence::kSequence);
1036 this->write("))");
1037 return true;
1038 }
1039 case k_unpackHalf2x16_IntrinsicKind: {
1040 this->write("float2(as_type<half2>(");
1041 this->writeExpression(*arguments[0], Precedence::kSequence);
1042 this->write("))");
1043 return true;
1044 }
1045 case k_floatBitsToInt_IntrinsicKind:
1046 case k_floatBitsToUint_IntrinsicKind:
1047 case k_intBitsToFloat_IntrinsicKind:
1048 case k_uintBitsToFloat_IntrinsicKind: {
1049 this->write(this->getBitcastIntrinsic(c.type()));
1050 this->write("(");
1051 this->writeExpression(*arguments[0], Precedence::kSequence);
1052 this->write(")");
1053 return true;
1054 }
1055 case k_degrees_IntrinsicKind: {
1056 this->write("((");
1057 this->writeExpression(*arguments[0], Precedence::kSequence);
1058 this->write(") * 57.2957795)");
1059 return true;
1060 }
1061 case k_radians_IntrinsicKind: {
1062 this->write("((");
1063 this->writeExpression(*arguments[0], Precedence::kSequence);
1064 this->write(") * 0.0174532925)");
1065 return true;
1066 }
1067 case k_dFdx_IntrinsicKind: {
1068 this->write("dfdx");
1069 this->writeArgumentList(c.arguments());
1070 return true;
1071 }
1072 case k_dFdy_IntrinsicKind: {
1073 if (!fRTFlipName.empty()) {
1074 this->write("(" + fRTFlipName + ".y * dfdy");
1075 } else {
1076 this->write("(dfdy");
1077 }
1078 this->writeArgumentList(c.arguments());
1079 this->write(")");
1080 return true;
1081 }
1082 case k_inverse_IntrinsicKind: {
1083 this->write(this->getInversePolyfill(arguments));
1084 this->writeArgumentList(c.arguments());
1085 return true;
1086 }
1087 case k_inversesqrt_IntrinsicKind: {
1088 this->write("rsqrt");
1089 this->writeArgumentList(c.arguments());
1090 return true;
1091 }
1092 case k_atan_IntrinsicKind: {
1093 this->write(c.arguments().size() == 2 ? "atan2" : "atan");
1094 this->writeArgumentList(c.arguments());
1095 return true;
1096 }
1097 case k_reflect_IntrinsicKind: {
1098 if (arguments[0]->type().columns() == 1) {
1099 // We need to synthesize `I - 2 * N * I * N`.
1100 std::string tmpI = this->getTempVariable(arguments[0]->type());
1101 std::string tmpN = this->getTempVariable(arguments[1]->type());
1102
1103 // (_skTempI = ...
1104 this->write("(" + tmpI + " = ");
1105 this->writeExpression(*arguments[0], Precedence::kSequence);
1106
1107 // , _skTempN = ...
1108 this->write(", " + tmpN + " = ");
1109 this->writeExpression(*arguments[1], Precedence::kSequence);
1110
1111 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
1112 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
1113 } else {
1114 this->writeSimpleIntrinsic(c);
1115 }
1116 return true;
1117 }
1118 case k_refract_IntrinsicKind: {
1119 if (arguments[0]->type().columns() == 1) {
1120 // Metal does implement refract for vectors; rather than reimplementing refract from
1121 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
1122 this->write("(refract(float2(");
1123 this->writeExpression(*arguments[0], Precedence::kSequence);
1124 this->write(", 0), float2(");
1125 this->writeExpression(*arguments[1], Precedence::kSequence);
1126 this->write(", 0), ");
1127 this->writeExpression(*arguments[2], Precedence::kSequence);
1128 this->write(").x)");
1129 } else {
1130 this->writeSimpleIntrinsic(c);
1131 }
1132 return true;
1133 }
1134 case k_roundEven_IntrinsicKind: {
1135 this->write("rint");
1136 this->writeArgumentList(c.arguments());
1137 return true;
1138 }
1139 case k_bitCount_IntrinsicKind: {
1140 this->write("popcount(");
1141 this->writeExpression(*arguments[0], Precedence::kSequence);
1142 this->write(")");
1143 return true;
1144 }
1145 case k_findLSB_IntrinsicKind: {
1146 // Create a temp variable to store the expression, to avoid double-evaluating it.
1147 std::string skTemp = this->getTempVariable(arguments[0]->type());
1148 std::string exprType = this->typeName(arguments[0]->type());
1149
1150 // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
1151 // Use select to detect zero inputs and force a -1 result.
1152
1153 // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
1154 this->write("(");
1155 this->write(skTemp);
1156 this->write(" = (");
1157 this->writeExpression(*arguments[0], Precedence::kSequence);
1158 this->write("), select(ctz(");
1159 this->write(skTemp);
1160 this->write("), ");
1161 this->write(exprType);
1162 this->write("(-1), ");
1163 this->write(skTemp);
1164 this->write(" == ");
1165 this->write(exprType);
1166 this->write("(0)))");
1167 return true;
1168 }
1169 case k_findMSB_IntrinsicKind: {
1170 // Create a temp variable to store the expression, to avoid double-evaluating it.
1171 std::string skTemp1 = this->getTempVariable(arguments[0]->type());
1172 std::string exprType = this->typeName(arguments[0]->type());
1173
1174 // GLSL findMSB is actually quite different from Metal's clz:
1175 // - For signed negative numbers, it returns the first zero bit, not the first one bit!
1176 // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
1177
1178 // (_skTemp1 = (.....),
1179 this->write("(");
1180 this->write(skTemp1);
1181 this->write(" = (");
1182 this->writeExpression(*arguments[0], Precedence::kSequence);
1183 this->write("), ");
1184
1185 // Signed input types might be negative; we need another helper variable to negate the
1186 // input (since we can only find one bits, not zero bits).
1187 std::string skTemp2;
1188 if (arguments[0]->type().isSigned()) {
1189 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
1190 skTemp2 = this->getTempVariable(arguments[0]->type());
1191 this->write(skTemp2);
1192 this->write(" = (select(");
1193 this->write(skTemp1);
1194 this->write(", ~");
1195 this->write(skTemp1);
1196 this->write(", ");
1197 this->write(skTemp1);
1198 this->write(" < 0)), ");
1199 } else {
1200 skTemp2 = skTemp1;
1201 }
1202
1203 // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
1204 this->write("select(");
1205 this->write(this->typeName(c.type()));
1206 this->write("(clz(");
1207 this->write(skTemp2);
1208 this->write(")), ");
1209 this->write(this->typeName(c.type()));
1210 this->write("(-1), ");
1211 this->write(skTemp2);
1212 this->write(" == ");
1213 this->write(exprType);
1214 this->write("(0)))");
1215 return true;
1216 }
1217 case k_sign_IntrinsicKind: {
1218 if (arguments[0]->type().componentType().isInteger()) {
1219 // Create a temp variable to store the expression, to avoid double-evaluating it.
1220 std::string skTemp = this->getTempVariable(arguments[0]->type());
1221 std::string exprType = this->typeName(arguments[0]->type());
1222
1223 // (_skTemp = (.....),
1224 this->write("(");
1225 this->write(skTemp);
1226 this->write(" = (");
1227 this->writeExpression(*arguments[0], Precedence::kSequence);
1228 this->write("), ");
1229
1230 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
1231 this->write("select(select(");
1232 this->write(exprType);
1233 this->write("(0), ");
1234 this->write(exprType);
1235 this->write("(-1), ");
1236 this->write(skTemp);
1237 this->write(" < 0), ");
1238 this->write(exprType);
1239 this->write("(1), ");
1240 this->write(skTemp);
1241 this->write(" > 0))");
1242 } else {
1243 this->writeSimpleIntrinsic(c);
1244 }
1245 return true;
1246 }
1247 case k_matrixCompMult_IntrinsicKind: {
1248 this->writeMatrixCompMult();
1249 this->writeSimpleIntrinsic(c);
1250 return true;
1251 }
1252 case k_outerProduct_IntrinsicKind: {
1253 this->writeOuterProduct();
1254 this->writeSimpleIntrinsic(c);
1255 return true;
1256 }
1257 case k_mix_IntrinsicKind: {
1258 SkASSERT(c.arguments().size() == 3);
1259 if (arguments[2]->type().componentType().isBoolean()) {
1260 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
1261 this->write("select");
1262 this->writeArgumentList(c.arguments());
1263 return true;
1264 }
1265 // The basic form of mix() is supported by Metal as-is.
1266 this->writeSimpleIntrinsic(c);
1267 return true;
1268 }
1269 case k_equal_IntrinsicKind:
1270 case k_greaterThan_IntrinsicKind:
1271 case k_greaterThanEqual_IntrinsicKind:
1272 case k_lessThan_IntrinsicKind:
1273 case k_lessThanEqual_IntrinsicKind:
1274 case k_notEqual_IntrinsicKind: {
1275 this->write("(");
1276 this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1277 switch (kind) {
1278 case k_equal_IntrinsicKind:
1279 this->write(" == ");
1280 break;
1281 case k_notEqual_IntrinsicKind:
1282 this->write(" != ");
1283 break;
1284 case k_lessThan_IntrinsicKind:
1285 this->write(" < ");
1286 break;
1287 case k_lessThanEqual_IntrinsicKind:
1288 this->write(" <= ");
1289 break;
1290 case k_greaterThan_IntrinsicKind:
1291 this->write(" > ");
1292 break;
1293 case k_greaterThanEqual_IntrinsicKind:
1294 this->write(" >= ");
1295 break;
1296 default:
1297 SK_ABORT("unsupported comparison intrinsic kind");
1298 }
1299 this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1300 this->write(")");
1301 return true;
1302 }
1303 case k_storageBarrier_IntrinsicKind:
1304 this->write("threadgroup_barrier(mem_flags::mem_device)");
1305 return true;
1306 case k_workgroupBarrier_IntrinsicKind:
1307 this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1308 return true;
1309 case k_atomicAdd_IntrinsicKind:
1310 this->write("atomic_fetch_add_explicit(&");
1311 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1312 this->write(", ");
1313 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1314 this->write(", memory_order_relaxed)");
1315 return true;
1316 case k_atomicLoad_IntrinsicKind:
1317 this->write("atomic_load_explicit(&");
1318 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1319 this->write(", memory_order_relaxed)");
1320 return true;
1321 case k_atomicStore_IntrinsicKind:
1322 this->write("atomic_store_explicit(&");
1323 this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1324 this->write(", ");
1325 this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1326 this->write(", memory_order_relaxed)");
1327 return true;
1328 default:
1329 return false;
1330 }
1331 }
1332
1333 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
1334 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int columns,int rows)1335 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows) {
1336 SkASSERT(rows <= 4);
1337 SkASSERT(columns <= 4);
1338
1339 std::string matrixType = this->typeName(sourceMatrix.componentType());
1340
1341 const char* separator = "";
1342 for (int c = 0; c < columns; ++c) {
1343 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1344 separator = "), ";
1345
1346 // Determine how many values to take from the source matrix for this row.
1347 int swizzleLength = 0;
1348 if (c < sourceMatrix.columns()) {
1349 swizzleLength = std::min<>(rows, sourceMatrix.rows());
1350 }
1351
1352 // Emit all the values from the source matrix row.
1353 bool firstItem;
1354 switch (swizzleLength) {
1355 case 0: firstItem = true; break;
1356 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
1357 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
1358 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
1359 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1360 default: SkUNREACHABLE;
1361 }
1362
1363 // Emit the placeholder identity-matrix cells.
1364 for (int r = swizzleLength; r < rows; ++r) {
1365 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1366 firstItem = false;
1367 }
1368 }
1369
1370 fExtraFunctions.writeText(")");
1371 }
1372
1373 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1374 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1375 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1376 int columns,
1377 int rows) {
1378 SkASSERT(rows <= 4);
1379 SkASSERT(columns <= 4);
1380
1381 std::string matrixType = this->typeName(ctor.type().componentType());
1382 size_t argIndex = 0;
1383 int argPosition = 0;
1384 auto args = ctor.argumentSpan();
1385
1386 static constexpr char kSwizzle[] = "xyzw";
1387 const char* separator = "";
1388 for (int c = 0; c < columns; ++c) {
1389 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1390 separator = "), ";
1391
1392 const char* columnSeparator = "";
1393 for (int r = 0; r < rows;) {
1394 fExtraFunctions.writeText(columnSeparator);
1395 columnSeparator = ", ";
1396
1397 if (argIndex < args.size()) {
1398 const Type& argType = args[argIndex]->type();
1399 switch (argType.typeKind()) {
1400 case Type::TypeKind::kScalar: {
1401 fExtraFunctions.printf("x%zu", argIndex);
1402 ++r;
1403 ++argPosition;
1404 break;
1405 }
1406 case Type::TypeKind::kVector: {
1407 fExtraFunctions.printf("x%zu.", argIndex);
1408 do {
1409 fExtraFunctions.write8(kSwizzle[argPosition]);
1410 ++r;
1411 ++argPosition;
1412 } while (r < rows && argPosition < argType.columns());
1413 break;
1414 }
1415 case Type::TypeKind::kMatrix: {
1416 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1417 do {
1418 fExtraFunctions.write8(kSwizzle[argPosition]);
1419 ++r;
1420 ++argPosition;
1421 } while (r < rows && (argPosition % argType.rows()) != 0);
1422 break;
1423 }
1424 default: {
1425 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1426 fExtraFunctions.writeText("<error>");
1427 break;
1428 }
1429 }
1430
1431 if (argPosition >= argType.columns() * argType.rows()) {
1432 ++argIndex;
1433 argPosition = 0;
1434 }
1435 } else {
1436 SkDEBUGFAIL("not enough arguments for matrix constructor");
1437 fExtraFunctions.writeText("<error>");
1438 }
1439 }
1440 }
1441
1442 if (argPosition != 0 || argIndex != args.size()) {
1443 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1444 fExtraFunctions.writeText(", <error>");
1445 }
1446
1447 fExtraFunctions.writeText(")");
1448 }
1449
1450 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1451 // Keeps track of previously generated constructors so that we won't generate more than one
1452 // constructor for any given permutation of input argument types. Returns the name of the
1453 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1454 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1455 const Type& type = c.type();
1456 int columns = type.columns();
1457 int rows = type.rows();
1458 auto args = c.argumentSpan();
1459 std::string typeName = this->typeName(type);
1460
1461 // Create the helper-method name and use it as our lookup key.
1462 std::string name = String::printf("%s_from", typeName.c_str());
1463 for (const std::unique_ptr<Expression>& expr : args) {
1464 String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1465 }
1466
1467 // If a helper-method has not been synthesized yet, create it now.
1468 if (!fHelpers.contains(name)) {
1469 fHelpers.add(name);
1470
1471 // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1472 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1473 // supply a mixture of scalars and vectors.)
1474 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1475
1476 size_t argIndex = 0;
1477 const char* argSeparator = "";
1478 for (const std::unique_ptr<Expression>& expr : args) {
1479 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1480 this->typeName(expr->type()).c_str(), argIndex++);
1481 argSeparator = ", ";
1482 }
1483
1484 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
1485
1486 if (args.size() == 1 && args.front()->type().isMatrix()) {
1487 this->assembleMatrixFromMatrix(args.front()->type(), columns, rows);
1488 } else {
1489 this->assembleMatrixFromExpressions(c, columns, rows);
1490 }
1491
1492 fExtraFunctions.writeText(");\n}\n");
1493 }
1494 return name;
1495 }
1496
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1497 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1498 SkASSERT(c.type().isMatrix());
1499
1500 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1501 // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1502 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1503 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1504 // that converts our inputs into a properly-shaped matrix.
1505 // A matrix construct helper method is always used if any input argument is a matrix.
1506 // Helper methods are also necessary when any argument would span multiple rows. For instance:
1507 //
1508 // float2 x = (1, 2);
1509 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1510 // | 2 4 6 |
1511 //
1512 // float2 x = (2, 3);
1513 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1514 // | 2 4 6 |
1515 //
1516 // float4 x = (1, 2, 3, 4);
1517 // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1518 // | 2 4 |
1519 //
1520
1521 int position = 0;
1522 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1523 // If an input argument is a matrix, we need a helper function.
1524 if (expr->type().isMatrix()) {
1525 return true;
1526 }
1527 position += expr->type().columns();
1528 if (position > c.type().rows()) {
1529 // An input argument would span multiple rows; a helper function is required.
1530 return true;
1531 }
1532 if (position == c.type().rows()) {
1533 // We've advanced to the end of a row. Wrap to the start of the next row.
1534 position = 0;
1535 }
1536 }
1537
1538 return false;
1539 }
1540
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1541 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1542 Precedence parentPrecedence) {
1543 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1544 // matrix-construct helper here.
1545 this->write(this->getMatrixConstructHelper(c));
1546 this->write("(");
1547 this->writeExpression(*c.argument(), Precedence::kSequence);
1548 this->write(")");
1549 }
1550
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1551 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1552 Precedence parentPrecedence) {
1553 if (c.type().isVector()) {
1554 this->writeConstructorCompoundVector(c, parentPrecedence);
1555 } else if (c.type().isMatrix()) {
1556 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1557 } else {
1558 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1559 }
1560 }
1561
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1562 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1563 Precedence parentPrecedence) {
1564 const Type& inType = c.argument()->type().componentType();
1565 const Type& outType = c.type().componentType();
1566 std::string inTypeName = this->typeName(inType);
1567 std::string outTypeName = this->typeName(outType);
1568
1569 std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1570 if (!fHelpers.contains(name)) {
1571 fHelpers.add(name);
1572 fExtraFunctions.printf(R"(
1573 template <size_t N>
1574 array<%s, N> %s(thread const array<%s, N>& x) {
1575 array<%s, N> result;
1576 for (int i = 0; i < N; ++i) {
1577 result[i] = %s(x[i]);
1578 }
1579 return result;
1580 }
1581 )",
1582 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1583 outTypeName.c_str(),
1584 outTypeName.c_str());
1585 }
1586
1587 this->write(name);
1588 this->write("(");
1589 this->writeExpression(*c.argument(), Precedence::kSequence);
1590 this->write(")");
1591 }
1592
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1593 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1594 SkASSERT(matrixType.isMatrix());
1595 SkASSERT(matrixType.rows() == 2);
1596 SkASSERT(matrixType.columns() == 2);
1597
1598 std::string baseType = this->typeName(matrixType.componentType());
1599 std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1600 if (!fHelpers.contains(name)) {
1601 fHelpers.add(name);
1602
1603 fExtraFunctions.printf(R"(
1604 %s4 %s(%s2x2 x) {
1605 return %s4(x[0].xy, x[1].xy);
1606 }
1607 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1608 }
1609
1610 return name;
1611 }
1612
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1613 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1614 Precedence parentPrecedence) {
1615 SkASSERT(c.type().isVector());
1616
1617 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1618 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1619 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1620 const Expression& expr = *c.argumentSpan().front();
1621 if (expr.type().isMatrix()) {
1622 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1623 this->write("(");
1624 this->writeExpression(expr, Precedence::kSequence);
1625 this->write(")");
1626 return;
1627 }
1628 }
1629
1630 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1631 }
1632
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1633 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1634 Precedence parentPrecedence) {
1635 SkASSERT(c.type().isMatrix());
1636
1637 // Emit and invoke a matrix-constructor helper method if one is necessary.
1638 if (this->matrixConstructHelperIsNeeded(c)) {
1639 this->write(this->getMatrixConstructHelper(c));
1640 this->write("(");
1641 const char* separator = "";
1642 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1643 this->write(separator);
1644 separator = ", ";
1645 this->writeExpression(*expr, Precedence::kSequence);
1646 }
1647 this->write(")");
1648 return;
1649 }
1650
1651 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1652 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1653 // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1654 // can group our inputs up and synthesize a constructor for each column.
1655 const Type& matrixType = c.type();
1656 const Type& columnType = matrixType.columnType(fContext);
1657
1658 this->writeType(matrixType);
1659 this->write("(");
1660 const char* separator = "";
1661 int scalarCount = 0;
1662 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1663 this->write(separator);
1664 separator = ", ";
1665 if (arg->type().columns() < matrixType.rows()) {
1666 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1667 if (!scalarCount) {
1668 this->writeType(columnType);
1669 this->write("(");
1670 }
1671 scalarCount += arg->type().columns();
1672 }
1673 this->writeExpression(*arg, Precedence::kSequence);
1674 if (scalarCount && scalarCount == matrixType.rows()) {
1675 // Close our `floatN(...` constructor block from above.
1676 this->write(")");
1677 scalarCount = 0;
1678 }
1679 }
1680 this->write(")");
1681 }
1682
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1683 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1684 const char* leftBracket,
1685 const char* rightBracket,
1686 Precedence parentPrecedence) {
1687 this->writeType(c.type());
1688 this->write(leftBracket);
1689 const char* separator = "";
1690 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1691 this->write(separator);
1692 separator = ", ";
1693 this->writeExpression(*arg, Precedence::kSequence);
1694 }
1695 this->write(rightBracket);
1696 }
1697
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1698 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1699 const char* leftBracket,
1700 const char* rightBracket,
1701 Precedence parentPrecedence) {
1702 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1703 }
1704
writeFragCoord()1705 void MetalCodeGenerator::writeFragCoord() {
1706 if (!fRTFlipName.empty()) {
1707 this->write("float4(_fragCoord.x, ");
1708 this->write(fRTFlipName.c_str());
1709 this->write(".x + ");
1710 this->write(fRTFlipName.c_str());
1711 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1712 } else {
1713 this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1714 }
1715 }
1716
is_compute_builtin(const Variable & var)1717 static bool is_compute_builtin(const Variable& var) {
1718 switch (var.layout().fBuiltin) {
1719 case SK_NUMWORKGROUPS_BUILTIN:
1720 case SK_WORKGROUPID_BUILTIN:
1721 case SK_LOCALINVOCATIONID_BUILTIN:
1722 case SK_GLOBALINVOCATIONID_BUILTIN:
1723 case SK_LOCALINVOCATIONINDEX_BUILTIN:
1724 return true;
1725 default:
1726 break;
1727 }
1728 return false;
1729 }
1730
1731 // true if the var is part of the Inputs struct
is_input(const Variable & var)1732 static bool is_input(const Variable& var) {
1733 SkASSERT(var.storage() == VariableStorage::kGlobal);
1734 return var.modifierFlags() & ModifierFlag::kIn &&
1735 (var.layout().fBuiltin == -1 || is_compute_builtin(var)) &&
1736 var.type().typeKind() != Type::TypeKind::kTexture;
1737 }
1738
1739 // true if the var is part of the Outputs struct
is_output(const Variable & var)1740 static bool is_output(const Variable& var) {
1741 SkASSERT(var.storage() == VariableStorage::kGlobal);
1742 // inout vars get written into the Inputs struct, so we exclude them from Outputs
1743 return (var.modifierFlags() & ModifierFlag::kOut) &&
1744 !(var.modifierFlags() & ModifierFlag::kIn) &&
1745 var.layout().fBuiltin == -1 &&
1746 var.type().typeKind() != Type::TypeKind::kTexture;
1747 }
1748
1749 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1750 static bool is_uniforms(const Variable& var) {
1751 SkASSERT(var.storage() == VariableStorage::kGlobal);
1752 return var.modifierFlags().isUniform() &&
1753 var.type().typeKind() != Type::TypeKind::kSampler;
1754 }
1755
1756 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1757 static bool is_threadgroup(const Variable& var) {
1758 SkASSERT(var.storage() == VariableStorage::kGlobal);
1759 return var.modifierFlags().isWorkgroup();
1760 }
1761
1762 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1763 static bool is_in_globals(const Variable& var) {
1764 SkASSERT(var.storage() == VariableStorage::kGlobal);
1765 return !var.modifierFlags().isConst();
1766 }
1767
writeVariableReference(const VariableReference & ref)1768 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1769 switch (ref.variable()->layout().fBuiltin) {
1770 case SK_FRAGCOLOR_BUILTIN:
1771 this->write("_out.sk_FragColor");
1772 break;
1773 case SK_SAMPLEMASK_BUILTIN:
1774 this->write("_out.sk_SampleMask");
1775 break;
1776 case SK_SECONDARYFRAGCOLOR_BUILTIN:
1777 if (fCaps.fDualSourceBlendingSupport) {
1778 this->write("_out.sk_SecondaryFragColor");
1779 } else {
1780 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
1781 }
1782 break;
1783 case SK_FRAGCOORD_BUILTIN:
1784 this->writeFragCoord();
1785 break;
1786 case SK_SAMPLEMASKIN_BUILTIN:
1787 this->write("sk_SampleMaskIn");
1788 break;
1789 case SK_VERTEXID_BUILTIN:
1790 this->write("sk_VertexID");
1791 break;
1792 case SK_INSTANCEID_BUILTIN:
1793 this->write("sk_InstanceID");
1794 break;
1795 case SK_CLOCKWISE_BUILTIN:
1796 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1797 // clockwise to match Skia convention.
1798 if (!fRTFlipName.empty()) {
1799 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1800 } else {
1801 this->write("_frontFacing");
1802 }
1803 break;
1804 case SK_LASTFRAGCOLOR_BUILTIN:
1805 if (fCaps.fFBFetchColorName) {
1806 this->write(fCaps.fFBFetchColorName);
1807 } else {
1808 fContext.fErrors->error(ref.position(), "'sk_LastFragColor' not supported");
1809 }
1810 break;
1811 default:
1812 const Variable& var = *ref.variable();
1813 if (var.storage() == Variable::Storage::kGlobal) {
1814 if (is_input(var)) {
1815 this->write("_in.");
1816 } else if (is_output(var)) {
1817 this->write("_out.");
1818 } else if (is_uniforms(var)) {
1819 this->write("_uniforms.");
1820 } else if (is_threadgroup(var)) {
1821 this->write("_threadgroups.");
1822 } else if (is_in_globals(var)) {
1823 this->write("_globals.");
1824 }
1825 }
1826 this->writeName(var.mangledName());
1827 }
1828 }
1829
writeIndexInnerExpression(const Expression & expr)1830 void MetalCodeGenerator::writeIndexInnerExpression(const Expression& expr) {
1831 if (fIndexSubstitutionData) {
1832 // If this expression already exists in the index-substitution map, use the substitute.
1833 if (const std::string* existing = fIndexSubstitutionData->fMap.find(&expr)) {
1834 this->write(*existing);
1835 return;
1836 }
1837
1838 // If this expression is non-trivial, we will need to create a scratch variable and store
1839 // its value there.
1840 if (fIndexSubstitutionData->fCreateSubstitutes && !Analysis::IsTrivialExpression(expr)) {
1841 // Create a substitute variable and emit it into the main stream.
1842 std::string scratchVar = this->getTempVariable(expr.type());
1843 this->write(scratchVar);
1844
1845 // Initialize the substitute variable in the prefix-stream.
1846 AutoOutputStream outputToPrefixStream(this, &fIndexSubstitutionData->fPrefixStream);
1847 this->write(scratchVar);
1848 this->write(" = ");
1849 this->writeExpression(expr, Precedence::kAssignment);
1850 this->write(", ");
1851
1852 // Remember the substitute variable in our map.
1853 fIndexSubstitutionData->fMap.set(&expr, std::move(scratchVar));
1854 return;
1855 }
1856 }
1857
1858 // We don't require index-substitution; just emit the expression normally.
1859 this->writeExpression(expr, Precedence::kExpression);
1860 }
1861
writeIndexExpression(const IndexExpression & expr)1862 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1863 // Metal does not seem to handle assignment into `vec.zyx[i]` properly--it compiles, but the
1864 // results are wrong. We rewrite the expression as `vec[uint3(2,1,0)[i]]` instead. (Filed with
1865 // Apple as FB12055941.)
1866 if (expr.base()->is<Swizzle>() && expr.base()->as<Swizzle>().components().size() > 1) {
1867 const Swizzle& swizzle = expr.base()->as<Swizzle>();
1868 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1869 this->write("[uint" + std::to_string(swizzle.components().size()) + "(");
1870 auto separator = SkSL::String::Separator();
1871 for (int8_t component : swizzle.components()) {
1872 this->write(separator());
1873 this->write(std::to_string(component));
1874 }
1875 this->write(")[");
1876 this->writeIndexInnerExpression(*expr.index());
1877 this->write("]]");
1878 } else {
1879 this->writeExpression(*expr.base(), Precedence::kPostfix);
1880 this->write("[");
1881 this->writeIndexInnerExpression(*expr.index());
1882 this->write("]");
1883 }
1884 }
1885
writeFieldAccess(const FieldAccess & f)1886 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1887 const Field* field = &f.base()->type().fields()[f.fieldIndex()];
1888 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1889 this->writeExpression(*f.base(), Precedence::kPostfix);
1890 this->write(".");
1891 }
1892 switch (field->fLayout.fBuiltin) {
1893 case SK_POSITION_BUILTIN:
1894 this->write("_out.sk_Position");
1895 break;
1896 case SK_POINTSIZE_BUILTIN:
1897 this->write("_out.sk_PointSize");
1898 break;
1899 default:
1900 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1901 this->write("_globals.");
1902 this->write(fInterfaceBlockNameMap[&f.base()->type()]);
1903 this->write("->");
1904 }
1905 this->writeName(field->fName);
1906 }
1907 }
1908
writeSwizzle(const Swizzle & swizzle)1909 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1910 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1911 this->write(".");
1912 this->write(Swizzle::MaskString(swizzle.components()));
1913 }
1914
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1915 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1916 const Type& result) {
1917 SkASSERT(left.isMatrix());
1918 SkASSERT(right.isMatrix());
1919 SkASSERT(result.isMatrix());
1920
1921 std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1922
1923 if (!fHelpers.contains(key)) {
1924 fHelpers.add(key);
1925 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1926 " left = left * right;\n"
1927 " return left;\n"
1928 "}\n",
1929 this->typeName(result).c_str(), this->typeName(left).c_str(),
1930 this->typeName(right).c_str());
1931 }
1932 }
1933
writeMatrixEqualityHelpers(const Type & left,const Type & right)1934 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1935 SkASSERT(left.isMatrix());
1936 SkASSERT(right.isMatrix());
1937 SkASSERT(left.rows() == right.rows());
1938 SkASSERT(left.columns() == right.columns());
1939
1940 std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1941
1942 if (!fHelpers.contains(key)) {
1943 fHelpers.add(key);
1944 fExtraFunctionPrototypes.printf(R"(
1945 thread bool operator==(const %s left, const %s right);
1946 thread bool operator!=(const %s left, const %s right);
1947 )",
1948 this->typeName(left).c_str(),
1949 this->typeName(right).c_str(),
1950 this->typeName(left).c_str(),
1951 this->typeName(right).c_str());
1952
1953 fExtraFunctions.printf(
1954 "thread bool operator==(const %s left, const %s right) {\n"
1955 " return ",
1956 this->typeName(left).c_str(), this->typeName(right).c_str());
1957
1958 const char* separator = "";
1959 for (int index=0; index<left.columns(); ++index) {
1960 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1961 separator = " &&\n ";
1962 }
1963
1964 fExtraFunctions.printf(
1965 ";\n"
1966 "}\n"
1967 "thread bool operator!=(const %s left, const %s right) {\n"
1968 " return !(left == right);\n"
1969 "}\n",
1970 this->typeName(left).c_str(), this->typeName(right).c_str());
1971 }
1972 }
1973
writeMatrixDivisionHelpers(const Type & type)1974 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1975 SkASSERT(type.isMatrix());
1976
1977 std::string key = "Matrix / " + this->typeName(type);
1978
1979 if (!fHelpers.contains(key)) {
1980 fHelpers.add(key);
1981 std::string typeName = this->typeName(type);
1982
1983 fExtraFunctions.printf(
1984 "thread %s operator/(const %s left, const %s right) {\n"
1985 " return %s(",
1986 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1987
1988 const char* separator = "";
1989 for (int index=0; index<type.columns(); ++index) {
1990 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1991 separator = ", ";
1992 }
1993
1994 fExtraFunctions.printf(");\n"
1995 "}\n"
1996 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1997 " left = left / right;\n"
1998 " return left;\n"
1999 "}\n",
2000 typeName.c_str(), typeName.c_str(), typeName.c_str());
2001 }
2002 }
2003
// Emits the generic `array_ref` equality templates (operator== / operator!=). Any equality
// helper required by the array's component type is emitted first.
// The templates compare sizes first and then every element with `all(left[i] == right[i])`.
// They are written only once, guarded by a single fixed fHelpers key, because the same
// templates serve every array type.
void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
    SkASSERT(type.isArray());

    // If the array's component type needs a helper as well, we need to emit that one first.
    this->writeEqualityHelpers(type.componentType(), type.componentType());

    // One key for all array types: the templates below are type-generic.
    std::string key = "ArrayEquality []";
    if (!fHelpers.contains(key)) {
        fHelpers.add(key);
        // Template declarations go to the prototype stream; the definitions follow below.
        fExtraFunctionPrototypes.writeText(R"(
template <typename T1, typename T2>
bool operator==(const array_ref<T1> left, const array_ref<T2> right);
template <typename T1, typename T2>
bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
)");
        fExtraFunctions.writeText(R"(
template <typename T1, typename T2>
bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
if (left.size() != right.size()) {
return false;
}
for (size_t index = 0; index < left.size(); ++index) {
if (!all(left[index] == right[index])) {
return false;
}
}
return true;
}

template <typename T1, typename T2>
bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
return !(left == right);
}
)");
    }
}
2040
writeStructEqualityHelpers(const Type & type)2041 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
2042 SkASSERT(type.isStruct());
2043 std::string key = "StructEquality " + this->typeName(type);
2044
2045 if (!fHelpers.contains(key)) {
2046 fHelpers.add(key);
2047 // If one of the struct's fields needs a helper as well, we need to emit that one first.
2048 for (const Field& field : type.fields()) {
2049 this->writeEqualityHelpers(*field.fType, *field.fType);
2050 }
2051
2052 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
2053 // and GLSL but do not exist by default in Metal.
2054 fExtraFunctionPrototypes.printf(R"(
2055 thread bool operator==(thread const %s& left, thread const %s& right);
2056 thread bool operator!=(thread const %s& left, thread const %s& right);
2057 )",
2058 this->typeName(type).c_str(),
2059 this->typeName(type).c_str(),
2060 this->typeName(type).c_str(),
2061 this->typeName(type).c_str());
2062
2063 fExtraFunctions.printf(
2064 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
2065 " return ",
2066 this->typeName(type).c_str(),
2067 this->typeName(type).c_str());
2068
2069 const char* separator = "";
2070 for (const Field& field : type.fields()) {
2071 if (field.fType->isArray()) {
2072 fExtraFunctions.printf(
2073 "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
2074 separator,
2075 (int)field.fName.size(), field.fName.data(),
2076 (int)field.fName.size(), field.fName.data());
2077 } else {
2078 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
2079 separator,
2080 (int)field.fName.size(), field.fName.data(),
2081 (int)field.fName.size(), field.fName.data());
2082 }
2083 separator = " &&\n ";
2084 }
2085 fExtraFunctions.printf(
2086 ";\n"
2087 "}\n"
2088 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
2089 " return !(left == right);\n"
2090 "}\n",
2091 this->typeName(type).c_str(),
2092 this->typeName(type).c_str());
2093 }
2094 }
2095
writeEqualityHelpers(const Type & leftType,const Type & rightType)2096 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
2097 if (leftType.isArray() && rightType.isArray()) {
2098 this->writeArrayEqualityHelpers(leftType);
2099 return;
2100 }
2101 if (leftType.isStruct() && rightType.isStruct()) {
2102 this->writeStructEqualityHelpers(leftType);
2103 return;
2104 }
2105 if (leftType.isMatrix() && rightType.isMatrix()) {
2106 this->writeMatrixEqualityHelpers(leftType, rightType);
2107 return;
2108 }
2109 }
2110
splatMatrixOf1(const Type & type)2111 std::string MetalCodeGenerator::splatMatrixOf1(const Type& type) {
2112 std::string str = this->typeName(type) + '(';
2113
2114 auto separator = SkSL::String::Separator();
2115 for (int index = type.slotCount(); index--;) {
2116 str += separator();
2117 str += "1.0";
2118 }
2119
2120 return str + ')';
2121 }
2122
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)2123 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
2124 SkASSERT(expr.type().isNumber());
2125 SkASSERT(matrixType.isMatrix());
2126
2127 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
2128 this->write("(");
2129 this->write(this->splatMatrixOf1(matrixType));
2130 this->write(" * ");
2131 this->writeExpression(expr, Precedence::kMultiplicative);
2132 this->write(")");
2133 }
2134
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)2135 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
2136 Operator op,
2137 const Expression& other,
2138 Precedence precedence) {
2139 bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
2140 op.isValidForMatrixOrVector() &&
2141 op.removeAssignment().kind() != Operator::Kind::STAR;
2142 if (needMatrixSplatOnScalar) {
2143 this->writeNumberAsMatrix(expr, other.type());
2144 } else if (op.isEquality() && expr.type().isArray()) {
2145 this->write("make_array_ref(");
2146 this->writeExpression(expr, precedence);
2147 this->write(")");
2148 } else {
2149 this->writeExpression(expr, precedence);
2150 }
2151 }
2152
// Writes a binary expression.
// Before writing, emits any helper the operator needs: equality helpers for `==` / `!=`, a
// matrix `*=` helper, and matrix division helpers for `/` and `/=`. Vector `==` / `!=` is
// wrapped in all(...) / any(...). Compound assignment into a swizzle is rewritten from
// `lhs op= rhs` to `lhs = lhs op rhs` under index substitution.
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
                                               Precedence parentPrecedence) {
    const Expression& left = *b.left();
    const Expression& right = *b.right();
    const Type& leftType = left.type();
    const Type& rightType = right.type();
    Operator op = b.getOperator();
    Precedence precedence = op.getBinaryPrecedence();
    bool needParens = precedence >= parentPrecedence;
    switch (op.kind()) {
        case Operator::Kind::EQEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Vector comparisons are reduced with all(); forcing needParens makes the
            // parentheses written below act as all()'s argument list.
            if (leftType.isVector()) {
                this->write("all");
                needParens = true;
            }
            break;
        case Operator::Kind::NEQ:
            this->writeEqualityHelpers(leftType, rightType);
            // Same as above, but reduced with any().
            if (leftType.isVector()) {
                this->write("any");
                needParens = true;
            }
            break;
        default:
            break;
    }
    // matrix *= matrix relies on an emitted operator*= overload.
    if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
        this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
    }
    // Matrix/matrix, scalar/matrix and matrix/scalar division (and `/=`) rely on emitted
    // operator/ and operator/= overloads for the matrix type.
    if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
        ((leftType.isMatrix() && rightType.isMatrix()) ||
         (leftType.isScalar() && rightType.isMatrix()) ||
         (leftType.isMatrix() && rightType.isScalar()))) {
        this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
    }

    if (needParens) {
        this->write("(");
    }

    // Some expressions need to be rewritten from `lhs *= rhs` to `lhs = lhs * rhs`, e.g.:
    //     float4 x = float4(1);
    //     x.xy *= float2x2(...);
    // will report the error "non-const reference cannot bind to vector element."
    if (op.isCompoundAssignment() && left.kind() == Expression::Kind::kSwizzle) {
        // We need to do the rewrite. This could be dangerous if the lhs contains an index
        // expression with a side effect (such as `array[Func()]`), so we enable index-substitution
        // here for the LHS; any index-expression with side effects will be evaluated into a scratch
        // variable.
        this->writeWithIndexSubstitution([&] {
            // The LHS is written twice: once as the assignment target and once as the left
            // operand of the non-compound operator.
            this->writeExpression(left, precedence);
            this->write(" = ");
            this->writeExpression(left, Precedence::kAssignment);
            this->write(operator_name(op.removeAssignment()));

            // We never want to create index-expression substitutes on the RHS of the expression;
            // the RHS is only emitted one time.
            fIndexSubstitutionData->fCreateSubstitutes = false;

            this->writeBinaryExpressionElement(right, op, left,
                                               op.removeAssignment().getBinaryPrecedence());
        });
    } else {
        // We don't need any rewrite; emit the binary expression as-is.
        this->writeBinaryExpressionElement(left, op, right, precedence);
        this->write(operator_name(op));
        this->writeBinaryExpressionElement(right, op, left, precedence);
    }

    if (needParens) {
        this->write(")");
    }
}
2227
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)2228 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
2229 Precedence parentPrecedence) {
2230 if (Precedence::kTernary >= parentPrecedence) {
2231 this->write("(");
2232 }
2233 this->writeExpression(*t.test(), Precedence::kTernary);
2234 this->write(" ? ");
2235 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
2236 this->write(" : ");
2237 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
2238 if (Precedence::kTernary >= parentPrecedence) {
2239 this->write(")");
2240 }
2241 }
2242
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)2243 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
2244 Precedence parentPrecedence) {
2245 const Operator op = p.getOperator();
2246 switch (op.kind()) {
2247 case Operator::Kind::PLUS:
2248 // According to the MSL specification, the arithmetic unary operators (+ and –) do not
2249 // act upon matrix-typed operands. We treat the unary "+" as a no-op for all operands.
2250 this->writeExpression(*p.operand(), Precedence::kPrefix);
2251 return;
2252
2253 case Operator::Kind::MINUS:
2254 // Transform the unary `-` on a matrix type to a multiplication by -1.
2255 if (p.operand()->type().isMatrix()) {
2256 this->write(p.type().componentType().highPrecision() ? "(-1.0 * "
2257 : "(-1.0h * ");
2258 this->writeExpression(*p.operand(), Precedence::kMultiplicative);
2259 this->write(")");
2260 return;
2261 }
2262 break;
2263
2264 case Operator::Kind::PLUSPLUS:
2265 case Operator::Kind::MINUSMINUS:
2266 if (p.operand()->type().isMatrix()) {
2267 // Transform `++x` or `--x` on a matrix type to `mat += T(1.0, ...)` or
2268 // `mat -= T(1.0, ...)`.
2269 this->write("(");
2270 this->writeExpression(*p.operand(), Precedence::kAssignment);
2271 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2272 this->write(this->splatMatrixOf1(p.operand()->type()));
2273 this->write(")");
2274 return;
2275 }
2276 break;
2277
2278 default:
2279 break;
2280 }
2281
2282 if (Precedence::kPrefix >= parentPrecedence) {
2283 this->write("(");
2284 }
2285
2286 this->write(op.tightOperatorName());
2287 this->writeExpression(*p.operand(), Precedence::kPrefix);
2288
2289 if (Precedence::kPrefix >= parentPrecedence) {
2290 this->write(")");
2291 }
2292 }
2293
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)2294 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
2295 Precedence parentPrecedence) {
2296 const Operator op = p.getOperator();
2297 switch (op.kind()) {
2298 case Operator::Kind::PLUSPLUS:
2299 case Operator::Kind::MINUSMINUS:
2300 if (p.operand()->type().isMatrix()) {
2301 // We need to transform `x++` or `x--` into `+=` and `-=` on a matrix.
2302 // Unfortunately, that requires making a temporary copy of the old value and
2303 // emitting a sequence expression: `((temp = mat), (mat += T(1.0, ...)), temp)`.
2304 std::string tempMatrix = this->getTempVariable(p.operand()->type());
2305 this->write("((");
2306 this->write(tempMatrix);
2307 this->write(" = ");
2308 this->writeExpression(*p.operand(), Precedence::kAssignment);
2309 this->write("), (");
2310 this->writeExpression(*p.operand(), Precedence::kAssignment);
2311 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2312 this->write(this->splatMatrixOf1(p.operand()->type()));
2313 this->write("), ");
2314 this->write(tempMatrix);
2315 this->write(")");
2316 return;
2317 }
2318 break;
2319
2320 default:
2321 break;
2322 }
2323
2324 if (Precedence::kPostfix >= parentPrecedence) {
2325 this->write("(");
2326 }
2327 this->writeExpression(*p.operand(), Precedence::kPostfix);
2328 this->write(op.tightOperatorName());
2329 if (Precedence::kPostfix >= parentPrecedence) {
2330 this->write(")");
2331 }
2332 }
2333
writeLiteral(const Literal & l)2334 void MetalCodeGenerator::writeLiteral(const Literal& l) {
2335 const Type& type = l.type();
2336 if (type.isFloat()) {
2337 this->write(l.description(OperatorPrecedence::kExpression));
2338 if (!l.type().highPrecision()) {
2339 this->write("h");
2340 }
2341 return;
2342 }
2343 if (type.isInteger()) {
2344 if (type.matches(*fContext.fTypes.fUInt)) {
2345 this->write(std::to_string(l.intValue() & 0xffffffff));
2346 this->write("u");
2347 } else if (type.matches(*fContext.fTypes.fUShort)) {
2348 this->write(std::to_string(l.intValue() & 0xffff));
2349 this->write("u");
2350 } else {
2351 this->write(std::to_string(l.intValue()));
2352 }
2353 return;
2354 }
2355 SkASSERT(type.isBoolean());
2356 this->write(l.description(OperatorPrecedence::kExpression));
2357 }
2358
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)2359 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
2360 const char*& separator) {
2361 Requirements requirements = this->requirements(f);
2362 if (requirements & kInputs_Requirement) {
2363 this->write(separator);
2364 this->write("_in");
2365 separator = ", ";
2366 }
2367 if (requirements & kOutputs_Requirement) {
2368 this->write(separator);
2369 this->write("_out");
2370 separator = ", ";
2371 }
2372 if (requirements & kUniforms_Requirement) {
2373 this->write(separator);
2374 this->write("_uniforms");
2375 separator = ", ";
2376 }
2377 if (requirements & kGlobals_Requirement) {
2378 this->write(separator);
2379 this->write("_globals");
2380 separator = ", ";
2381 }
2382 if (requirements & kFragCoord_Requirement) {
2383 this->write(separator);
2384 this->write("_fragCoord");
2385 separator = ", ";
2386 }
2387 if (requirements & kSampleMaskIn_Requirement) {
2388 this->write(separator);
2389 this->write("sk_SampleMaskIn");
2390 separator = ", ";
2391 }
2392 if (requirements & kVertexID_Requirement) {
2393 this->write(separator);
2394 this->write("sk_VertexID");
2395 separator = ", ";
2396 }
2397 if (requirements & kInstanceID_Requirement) {
2398 this->write(separator);
2399 this->write("sk_InstanceID");
2400 separator = ", ";
2401 }
2402 if (requirements & kThreadgroups_Requirement) {
2403 this->write(separator);
2404 this->write("_threadgroups");
2405 separator = ", ";
2406 }
2407 }
2408
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)2409 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2410 const char*& separator) {
2411 Requirements requirements = this->requirements(f);
2412 if (requirements & kInputs_Requirement) {
2413 this->write(separator);
2414 this->write("Inputs _in");
2415 separator = ", ";
2416 }
2417 if (requirements & kOutputs_Requirement) {
2418 this->write(separator);
2419 this->write("thread Outputs& _out");
2420 separator = ", ";
2421 }
2422 if (requirements & kUniforms_Requirement) {
2423 this->write(separator);
2424 this->write("Uniforms _uniforms");
2425 separator = ", ";
2426 }
2427 if (requirements & kGlobals_Requirement) {
2428 this->write(separator);
2429 this->write("thread Globals& _globals");
2430 separator = ", ";
2431 }
2432 if (requirements & kFragCoord_Requirement) {
2433 this->write(separator);
2434 this->write("float4 _fragCoord");
2435 separator = ", ";
2436 }
2437 if (requirements & kSampleMaskIn_Requirement) {
2438 this->write(separator);
2439 this->write("uint sk_SampleMaskIn");
2440 separator = ", ";
2441 }
2442 if (requirements & kVertexID_Requirement) {
2443 this->write(separator);
2444 this->write("uint sk_VertexID");
2445 separator = ", ";
2446 }
2447 if (requirements & kInstanceID_Requirement) {
2448 this->write(separator);
2449 this->write("uint sk_InstanceID");
2450 separator = ", ";
2451 }
2452 if (requirements & kThreadgroups_Requirement) {
2453 this->write(separator);
2454 this->write("threadgroup Threadgroups& _threadgroups");
2455 separator = ", ";
2456 }
2457 }
2458
getUniformBinding(const Layout & layout)2459 int MetalCodeGenerator::getUniformBinding(const Layout& layout) {
2460 return (layout.fBinding >= 0) ? layout.fBinding
2461 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2462 }
2463
getUniformSet(const Layout & layout)2464 int MetalCodeGenerator::getUniformSet(const Layout& layout) {
2465 return (layout.fSet >= 0) ? layout.fSet
2466 : fProgram.fConfig->fSettings.fDefaultUniformSet;
2467 }
2468
writeFunctionDeclaration(const FunctionDeclaration & f)2469 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
2470 fRTFlipName = (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None)
2471 ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
2472 : "";
2473 const char* separator = "";
2474 if (f.isMain()) {
2475 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2476 this->write("fragment Outputs fragmentMain(");
2477 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2478 this->write("vertex Outputs vertexMain(");
2479 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2480 this->write("kernel void computeMain(");
2481 } else {
2482 fContext.fErrors->error(Position(), "unsupported kind of program");
2483 return false;
2484 }
2485 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2486 this->write("Inputs _in [[stage_in]]");
2487 separator = ", ";
2488 }
2489 if (-1 != fUniformBuffer) {
2490 this->write(separator);
2491 this->write("constant Uniforms& _uniforms [[buffer(" +
2492 std::to_string(fUniformBuffer) + ")]]");
2493 separator = ", ";
2494 }
2495 for (const ProgramElement* e : fProgram.elements()) {
2496 if (e->is<GlobalVarDeclaration>()) {
2497 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2498 const VarDeclaration& decl = decls.varDeclaration();
2499 const Variable* var = decl.var();
2500 const SkSL::Type::TypeKind varKind = var->type().typeKind();
2501
2502 if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
2503 if (var->type().dimensions() != SpvDim2D) {
2504 // Not yet implemented--Skia currently only uses 2D textures.
2505 fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
2506 return false;
2507 }
2508
2509 int binding = getUniformBinding(var->layout());
2510 this->write(separator);
2511 separator = ", ";
2512
2513 if (varKind == Type::TypeKind::kSampler) {
2514 this->writeType(var->type().textureType());
2515 this->write(" ");
2516 this->writeName(var->mangledName());
2517 this->write(kTextureSuffix);
2518 this->write(" [[texture(");
2519 this->write(std::to_string(binding));
2520 this->write(")]], sampler ");
2521 this->writeName(var->mangledName());
2522 this->write(kSamplerSuffix);
2523 this->write(" [[sampler(");
2524 this->write(std::to_string(binding));
2525 this->write(")]]");
2526 } else {
2527 SkASSERT(varKind == Type::TypeKind::kTexture);
2528 this->writeType(var->type());
2529 this->write(" ");
2530 this->writeName(var->mangledName());
2531 this->write(" [[texture(");
2532 this->write(std::to_string(binding));
2533 this->write(")]]");
2534 }
2535 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2536 std::string_view attr;
2537 switch (var->layout().fBuiltin) {
2538 case SK_NUMWORKGROUPS_BUILTIN:
2539 attr = " [[threadgroups_per_grid]]";
2540 break;
2541 case SK_WORKGROUPID_BUILTIN:
2542 attr = " [[threadgroup_position_in_grid]]";
2543 break;
2544 case SK_LOCALINVOCATIONID_BUILTIN:
2545 attr = " [[thread_position_in_threadgroup]]";
2546 break;
2547 case SK_GLOBALINVOCATIONID_BUILTIN:
2548 attr = " [[thread_position_in_grid]]";
2549 break;
2550 case SK_LOCALINVOCATIONINDEX_BUILTIN:
2551 attr = " [[thread_index_in_threadgroup]]";
2552 break;
2553 default:
2554 break;
2555 }
2556 if (!attr.empty()) {
2557 this->write(separator);
2558 this->writeType(var->type());
2559 this->write(" ");
2560 this->write(var->name());
2561 this->write(attr);
2562 separator = ", ";
2563 }
2564 }
2565 } else if (e->is<InterfaceBlock>()) {
2566 const InterfaceBlock& intf = e->as<InterfaceBlock>();
2567 if (intf.typeName() == "sk_PerVertex") {
2568 continue;
2569 }
2570 this->write(separator);
2571 if (is_readonly(intf)) {
2572 this->write("const ");
2573 }
2574 this->write(is_buffer(intf) ? "device " : "constant ");
2575 this->writeType(intf.var()->type());
2576 this->write("& " );
2577 this->write(fInterfaceBlockNameMap[&intf.var()->type()]);
2578 this->write(" [[buffer(");
2579 this->write(std::to_string(this->getUniformBinding(intf.var()->layout())));
2580 this->write(")]]");
2581 separator = ", ";
2582 }
2583 }
2584 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2585 if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None &&
2586 fInterfaceBlockNameMap.empty()) {
2587 this->write(separator);
2588 this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
2589 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
2590 separator = ", ";
2591 }
2592 this->write(separator);
2593 this->write("bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]");
2594 if (this->requirements(f) & kSampleMaskIn_Requirement) {
2595 this->write(", uint sk_SampleMaskIn [[sample_mask]]");
2596 }
2597 if (fProgram.fInterface.fUseLastFragColor && fCaps.fFBFetchColorName) {
2598 this->write(", half4 " + std::string(fCaps.fFBFetchColorName) +
2599 " [[color(0)]]\n");
2600 }
2601 separator = ", ";
2602 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2603 this->write(separator);
2604 this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
2605 separator = ", ";
2606 }
2607 } else {
2608 this->writeType(f.returnType());
2609 this->write(" ");
2610 this->writeName(f.mangledName());
2611 this->write("(");
2612 this->writeFunctionRequirementParams(f, separator);
2613 }
2614 for (const Variable* param : f.parameters()) {
2615 // This is a workaround for our test files. They use the runtime effect signature, so main
2616 // takes a coords parameter. We detect these at IR generation time, and we omit them from
2617 // the declaration here, so the function is valid Metal. (Well, valid as long as the
2618 // coordinates aren't actually referenced.)
2619 if (f.isMain() && param == f.getMainCoordsParameter()) {
2620 continue;
2621 }
2622 this->write(separator);
2623 separator = ", ";
2624 this->writeModifiers(param->modifierFlags());
2625 this->writeType(param->type());
2626 if (pass_by_reference(param->type(), param->modifierFlags())) {
2627 this->write("&");
2628 }
2629 this->write(" ");
2630 this->writeName(param->mangledName());
2631 }
2632 this->write(")");
2633 return true;
2634 }
2635
writeFunctionPrototype(const FunctionPrototype & f)2636 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2637 this->writeFunctionDeclaration(f.declaration());
2638 this->writeLine(";");
2639 }
2640
is_block_ending_with_return(const Statement * stmt)2641 static bool is_block_ending_with_return(const Statement* stmt) {
2642 // This function detects (potentially nested) blocks that end in a return statement.
2643 if (!stmt->is<Block>()) {
2644 return false;
2645 }
2646 const StatementArray& block = stmt->as<Block>().children();
2647 for (int index = block.size(); index--; ) {
2648 stmt = block[index].get();
2649 if (stmt->is<ReturnStatement>()) {
2650 return true;
2651 }
2652 if (stmt->is<Block>()) {
2653 return is_block_ending_with_return(stmt);
2654 }
2655 if (!stmt->is<Nop>()) {
2656 break;
2657 }
2658 }
2659 return false;
2660 }
2661
writeComputeMainInputs()2662 void MetalCodeGenerator::writeComputeMainInputs() {
2663 // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2664 // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2665 // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2666 this->write("Inputs _in = { ");
2667 const char* separator = "";
2668 for (const ProgramElement* e : fProgram.elements()) {
2669 if (e->is<GlobalVarDeclaration>()) {
2670 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2671 const Variable* var = decls.varDeclaration().var();
2672 if (is_input(*var)) {
2673 this->write(separator);
2674 separator = ", ";
2675 this->writeName(var->mangledName());
2676 }
2677 }
2678 }
2679 this->writeLine(" };");
2680 }
2681
writeFunction(const FunctionDefinition & f)2682 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
2683 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
2684
2685 if (!this->writeFunctionDeclaration(f.declaration())) {
2686 return;
2687 }
2688
2689 fCurrentFunction = &f.declaration();
2690 SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
2691
2692 this->writeLine(" {");
2693
2694 if (f.declaration().isMain()) {
2695 fIndentation++;
2696 this->writeGlobalInit();
2697 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2698 this->writeThreadgroupInit();
2699 this->writeComputeMainInputs();
2700 }
2701 else {
2702 this->writeLine("Outputs _out;");
2703 this->writeLine("(void)_out;");
2704 }
2705 fIndentation--;
2706 }
2707
2708 fFunctionHeader.clear();
2709 StringStream buffer;
2710 {
2711 AutoOutputStream outputToBuffer(this, &buffer);
2712 fIndentation++;
2713 for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
2714 if (!stmt->isEmpty()) {
2715 this->writeStatement(*stmt);
2716 this->finishLine();
2717 }
2718 }
2719 if (f.declaration().isMain()) {
2720 // If the main function doesn't end with a return, we need to synthesize one here.
2721 if (!is_block_ending_with_return(f.body().get())) {
2722 this->writeReturnStatementFromMain();
2723 this->finishLine();
2724 }
2725 }
2726 fIndentation--;
2727 this->writeLine("}");
2728 }
2729 this->write(fFunctionHeader);
2730 this->write(buffer.str());
2731 }
2732
writeModifiers(ModifierFlags flags)2733 void MetalCodeGenerator::writeModifiers(ModifierFlags flags) {
2734 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2735 (flags & (ModifierFlag::kIn | ModifierFlag::kOut))) {
2736 this->write("device ");
2737 } else if (flags & ModifierFlag::kOut) {
2738 this->write("thread ");
2739 }
2740 if (flags.isConst()) {
2741 this->write("const ");
2742 }
2743 }
2744
writeInterfaceBlock(const InterfaceBlock & intf)2745 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2746 if (intf.typeName() == "sk_PerVertex") {
2747 return;
2748 }
2749 const Type* structType = &intf.var()->type().componentType();
2750 this->writeModifiers(intf.var()->modifierFlags());
2751 this->write("struct ");
2752 this->writeType(*structType);
2753 this->writeLine(" {");
2754 fIndentation++;
2755 this->writeFields(structType->fields(), structType->fPosition);
2756 if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
2757 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2758 }
2759 fIndentation--;
2760 this->write("}");
2761 if (!intf.instanceName().empty()) {
2762 this->write(" ");
2763 this->write(intf.instanceName());
2764 if (intf.arraySize() > 0) {
2765 this->write("[");
2766 this->write(std::to_string(intf.arraySize()));
2767 this->write("]");
2768 }
2769 fInterfaceBlockNameMap.set(&intf.var()->type(), std::string(intf.instanceName()));
2770 } else {
2771 fInterfaceBlockNameMap.set(&intf.var()->type(),
2772 "_anonInterface" + std::to_string(fAnonInterfaceCount++));
2773 }
2774 this->writeLine(";");
2775 }
2776
// Writes the members of a struct / interface block, validating them against the Metal memory
// layout (MemoryLayout::Standard::kMetal).
//
// For a field whose layout offset is not -1:
//  - it is an error if that offset lies before the end of the previous field, or is not a
//    multiple of the field type's alignment;
//  - if the offset lies past the end of the previous field, a `char padN[gap]` member is
//    inserted to fill the gap.
//
// Errors are reported through fContext.fErrors and stop field emission immediately.
void MetalCodeGenerator::writeFields(SkSpan<const Field> fields, Position parentPos) {
    MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
    // Byte offset just past the last sized field emitted so far.
    int currentOffset = 0;
    for (const Field& field : fields) {
        int fieldOffset = field.fLayout.fOffset;
        const Type* fieldType = field.fType;
        if (!memoryLayout.isSupported(*fieldType)) {
            fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
                                               "' is not permitted here");
            return;
        }
        // An offset of -1 skips the explicit-placement checks below.
        if (fieldOffset != -1) {
            if (currentOffset > fieldOffset) {
                // The requested offset would overlap the previous field.
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be at least " + std::to_string(currentOffset));
                return;
            } else if (currentOffset < fieldOffset) {
                // Fill the gap with an explicitly-sized byte array so the field lands exactly at
                // the requested offset; fPaddingCount keeps pad member names unique.
                this->write("char pad");
                this->write(std::to_string(fPaddingCount++));
                this->write("[");
                this->write(std::to_string(fieldOffset - currentOffset));
                this->writeLine("];");
                currentOffset = fieldOffset;
            }
            int alignment = memoryLayout.alignment(*fieldType);
            if (fieldOffset % alignment) {
                fContext.fErrors->error(field.fPosition,
                                        "offset of field '" + std::string(field.fName) +
                                        "' must be a multiple of " + std::to_string(alignment));
                return;
            }
        }
        if (fieldType->isUnsizedArray()) {
            // An unsized array always appears as the last member of a storage block. We declare
            // it as a one-element array and allow dereferencing past the capacity.
            // TODO(armansito): This is because C++ does not support flexible array members like C99
            // does. This generally works but it can lead to UB as compilers are free to insert
            // padding past the first element of the array. An alternative approach is to declare
            // the struct without the unsized array member and replace variable references with a
            // buffer offset calculation based on sizeof().
            // Note: currentOffset is not advanced here.
            this->writeModifiers(field.fModifierFlags);
            this->writeType(fieldType->componentType());
            this->write(" ");
            this->writeName(field.fName);
            this->write("[1]");
        } else {
            size_t fieldSize = memoryLayout.size(*fieldType);
            // Refuse to advance currentOffset past INT_MAX (it is a signed int).
            if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
                fContext.fErrors->error(parentPos, "field offset overflow");
                return;
            }
            currentOffset += fieldSize;
            this->writeModifiers(field.fModifierFlags);
            this->writeType(*fieldType);
            this->write(" ");
            this->writeName(field.fName);
        }
        this->writeLine(";");
    }
}
2838
writeVarInitializer(const Variable & var,const Expression & value)2839 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2840 this->writeExpression(value, Precedence::kExpression);
2841 }
2842
writeName(std::string_view name)2843 void MetalCodeGenerator::writeName(std::string_view name) {
2844 if (fReservedWords.contains(name)) {
2845 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2846 }
2847 this->write(name);
2848 }
2849
writeVarDeclaration(const VarDeclaration & varDecl)2850 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2851 this->writeModifiers(varDecl.var()->modifierFlags());
2852 this->writeType(varDecl.var()->type());
2853 this->write(" ");
2854 this->writeName(varDecl.var()->mangledName());
2855 if (varDecl.value()) {
2856 this->write(" = ");
2857 this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2858 }
2859 this->write(";");
2860 }
2861
writeStatement(const Statement & s)2862 void MetalCodeGenerator::writeStatement(const Statement& s) {
2863 switch (s.kind()) {
2864 case Statement::Kind::kBlock:
2865 this->writeBlock(s.as<Block>());
2866 break;
2867 case Statement::Kind::kExpression:
2868 this->writeExpressionStatement(s.as<ExpressionStatement>());
2869 break;
2870 case Statement::Kind::kReturn:
2871 this->writeReturnStatement(s.as<ReturnStatement>());
2872 break;
2873 case Statement::Kind::kVarDeclaration:
2874 this->writeVarDeclaration(s.as<VarDeclaration>());
2875 break;
2876 case Statement::Kind::kIf:
2877 this->writeIfStatement(s.as<IfStatement>());
2878 break;
2879 case Statement::Kind::kFor:
2880 this->writeForStatement(s.as<ForStatement>());
2881 break;
2882 case Statement::Kind::kDo:
2883 this->writeDoStatement(s.as<DoStatement>());
2884 break;
2885 case Statement::Kind::kSwitch:
2886 this->writeSwitchStatement(s.as<SwitchStatement>());
2887 break;
2888 case Statement::Kind::kBreak:
2889 this->write("break;");
2890 break;
2891 case Statement::Kind::kContinue:
2892 this->write("continue;");
2893 break;
2894 case Statement::Kind::kDiscard:
2895 this->write("discard_fragment();");
2896 break;
2897 case Statement::Kind::kNop:
2898 this->write(";");
2899 break;
2900 default:
2901 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2902 break;
2903 }
2904 }
2905
writeBlock(const Block & b)2906 void MetalCodeGenerator::writeBlock(const Block& b) {
2907 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2908 // something here to make the code valid).
2909 bool isScope = b.isScope() || b.isEmpty();
2910 if (isScope) {
2911 this->writeLine("{");
2912 fIndentation++;
2913 }
2914 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2915 if (!stmt->isEmpty()) {
2916 this->writeStatement(*stmt);
2917 this->finishLine();
2918 }
2919 }
2920 if (isScope) {
2921 fIndentation--;
2922 this->write("}");
2923 }
2924 }
2925
writeIfStatement(const IfStatement & stmt)2926 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2927 this->write("if (");
2928 this->writeExpression(*stmt.test(), Precedence::kExpression);
2929 this->write(") ");
2930 this->writeStatement(*stmt.ifTrue());
2931 if (stmt.ifFalse()) {
2932 this->write(" else ");
2933 this->writeStatement(*stmt.ifFalse());
2934 }
2935 }
2936
writeForStatement(const ForStatement & f)2937 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2938 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2939 if (!f.initializer() && f.test() && !f.next()) {
2940 this->write("while (");
2941 this->writeExpression(*f.test(), Precedence::kExpression);
2942 this->write(") ");
2943 this->writeStatement(*f.statement());
2944 return;
2945 }
2946
2947 this->write("for (");
2948 if (f.initializer() && !f.initializer()->isEmpty()) {
2949 this->writeStatement(*f.initializer());
2950 } else {
2951 this->write("; ");
2952 }
2953 if (f.test()) {
2954 this->writeExpression(*f.test(), Precedence::kExpression);
2955 }
2956 this->write("; ");
2957 if (f.next()) {
2958 this->writeExpression(*f.next(), Precedence::kExpression);
2959 }
2960 this->write(") ");
2961 this->writeStatement(*f.statement());
2962 }
2963
writeDoStatement(const DoStatement & d)2964 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2965 this->write("do ");
2966 this->writeStatement(*d.statement());
2967 this->write(" while (");
2968 this->writeExpression(*d.test(), Precedence::kExpression);
2969 this->write(");");
2970 }
2971
writeExpressionStatement(const ExpressionStatement & s)2972 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2973 if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2974 // Don't emit dead expressions.
2975 return;
2976 }
2977 this->writeExpression(*s.expression(), Precedence::kStatement);
2978 this->write(";");
2979 }
2980
writeSwitchStatement(const SwitchStatement & s)2981 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2982 this->write("switch (");
2983 this->writeExpression(*s.value(), Precedence::kExpression);
2984 this->writeLine(") {");
2985 fIndentation++;
2986 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2987 const SwitchCase& c = stmt->as<SwitchCase>();
2988 if (c.isDefault()) {
2989 this->writeLine("default:");
2990 } else {
2991 this->write("case ");
2992 this->write(std::to_string(c.value()));
2993 this->writeLine(":");
2994 }
2995 if (!c.statement()->isEmpty()) {
2996 fIndentation++;
2997 this->writeStatement(*c.statement());
2998 this->finishLine();
2999 fIndentation--;
3000 }
3001 }
3002 fIndentation--;
3003 this->write("}");
3004 }
3005
writeReturnStatementFromMain()3006 void MetalCodeGenerator::writeReturnStatementFromMain() {
3007 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
3008 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
3009 ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3010 this->write("return _out;");
3011 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3012 this->write("return;");
3013 } else {
3014 SkDEBUGFAIL("unsupported kind of program");
3015 }
3016 }
3017
writeReturnStatement(const ReturnStatement & r)3018 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
3019 if (fCurrentFunction && fCurrentFunction->isMain()) {
3020 if (r.expression()) {
3021 if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
3022 this->write("_out.sk_FragColor = ");
3023 this->writeExpression(*r.expression(), Precedence::kExpression);
3024 this->writeLine(";");
3025 } else {
3026 fContext.fErrors->error(r.fPosition,
3027 "Metal does not support returning '" +
3028 r.expression()->type().description() + "' from main()");
3029 }
3030 }
3031 this->writeReturnStatementFromMain();
3032 return;
3033 }
3034
3035 this->write("return");
3036 if (r.expression()) {
3037 this->write(" ");
3038 this->writeExpression(*r.expression(), Precedence::kExpression);
3039 }
3040 this->write(";");
3041 }
3042
writeHeader()3043 void MetalCodeGenerator::writeHeader() {
3044 this->writeLine("#include <metal_stdlib>");
3045 this->writeLine("#include <simd/simd.h>");
3046 this->writeLine("#ifdef __clang__");
3047 this->writeLine("#pragma clang diagnostic ignored \"-Wall\"");
3048 this->writeLine("#endif");
3049 this->writeLine("using namespace metal;");
3050 }
3051
// Emits, at most once per program, a Metal `sampler2D` struct and free helper functions that
// take it:
//  - the struct pairs a texture2d<half> with a sampler;
//  - the helpers are sample, sampleLod and sampleGrad.
// Nothing is written unless visitGlobalStruct reports at least one sampler global.
void MetalCodeGenerator::writeSampler2DPolyfill() {
    class : public GlobalStructVisitor {
    public:
        void visitSampler(const Type&, std::string_view) override {
            // Every sampler global triggers this callback, but the polyfill may only appear once.
            if (fWrotePolyfill) {
                return;
            }
            fWrotePolyfill = true;

            // Both `%g` placeholders (the default bias of the two `sample` overloads) are filled
            // with fTextureBias.
            std::string polyfill = SkSL::String::printf(R"(
struct sampler2D {
texture2d<half> tex;
sampler smp;
};
half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
half4 sampleLod(sampler2D i, float3 p, float lod) {
return i.tex.sample(i.smp, p.xy / p.z, level(lod));
}
half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
}

)",
                                                        fTextureBias,
                                                        fTextureBias);
            fCodeGen->write(polyfill.c_str());
        }

        MetalCodeGenerator* fCodeGen = nullptr;
        float fTextureBias = 0.0f;
        bool fWrotePolyfill = false;
    } visitor;

    visitor.fCodeGen = this;
    // The default sample bias is kSharpenTexturesBias when fSharpenTextures is set, otherwise 0.
    visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
                                                                         : 0.0f;
    this->visitGlobalStruct(&visitor);
}
3092
writeUniformStruct()3093 void MetalCodeGenerator::writeUniformStruct() {
3094 for (const ProgramElement* e : fProgram.elements()) {
3095 if (e->is<GlobalVarDeclaration>()) {
3096 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3097 const Variable& var = *decls.varDeclaration().var();
3098 if (var.modifierFlags().isUniform()) {
3099 SkASSERT(var.type().typeKind() != Type::TypeKind::kSampler &&
3100 var.type().typeKind() != Type::TypeKind::kTexture);
3101 int uniformSet = this->getUniformSet(var.layout());
3102 // Make sure that the program's uniform-set value is consistent throughout.
3103 if (-1 == fUniformBuffer) {
3104 this->write("struct Uniforms {\n");
3105 fUniformBuffer = uniformSet;
3106 } else if (uniformSet != fUniformBuffer) {
3107 fContext.fErrors->error(decls.fPosition,
3108 "Metal backend requires all uniforms to have the same "
3109 "'layout(set=...)'");
3110 }
3111 this->write(" ");
3112 this->writeType(var.type());
3113 this->write(" ");
3114 this->writeName(var.mangledName());
3115 this->write(";\n");
3116 }
3117 }
3118 }
3119 if (-1 != fUniformBuffer) {
3120 this->write("};\n");
3121 }
3122 }
3123
writeInterpolatedAttributes(const Variable & var)3124 void MetalCodeGenerator::writeInterpolatedAttributes(const Variable& var) {
3125 SkASSERT((is_output(var) && ProgramConfig::IsVertex(fProgram.fConfig->fKind)) ||
3126 (is_input(var) && ProgramConfig::IsFragment(fProgram.fConfig->fKind)));
3127 // Always include the location
3128 this->write(" [[user(locn");
3129 this->write(std::to_string(var.layout().fLocation));
3130 this->write(")");
3131
3132 if (var.modifierFlags().isFlat()) {
3133 this->write(" flat");
3134 } else if (var.modifierFlags().isNoPerspective()) {
3135 this->write(" center_no_perspective");
3136 } // else default behavior is center_perspective
3137
3138 this->write("]]");
3139 }
3140
writeInputStruct()3141 void MetalCodeGenerator::writeInputStruct() {
3142 this->write("struct Inputs {\n");
3143 for (const ProgramElement* e : fProgram.elements()) {
3144 if (e->is<GlobalVarDeclaration>()) {
3145 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3146 const Variable& var = *decls.varDeclaration().var();
3147 if (is_input(var)) {
3148 this->write(" ");
3149 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3150 needs_address_space(var.type(), var.modifierFlags())) {
3151 // TODO: address space support
3152 this->write("device ");
3153 }
3154 this->writeType(var.type());
3155 if (pass_by_reference(var.type(), var.modifierFlags())) {
3156 this->write("&");
3157 }
3158 this->write(" ");
3159 this->writeName(var.mangledName());
3160 if (-1 != var.layout().fLocation) {
3161 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3162 this->write(" [[attribute(" + std::to_string(var.layout().fLocation) +
3163 ")]]");
3164 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3165 // Write attributes for the fragment input that are consistent with
3166 // what's annotated on the vertex output.
3167 this->writeInterpolatedAttributes(var);
3168 }
3169 }
3170 this->write(";\n");
3171 }
3172 }
3173 }
3174 this->write("};\n");
3175 }
3176
writeOutputStruct()3177 void MetalCodeGenerator::writeOutputStruct() {
3178 this->write("struct Outputs {\n");
3179 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3180 this->write(" float4 sk_Position [[position]];\n");
3181 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3182 this->write(" half4 sk_FragColor [[color(0)]];\n");
3183 if (fProgram.fInterface.fOutputSecondaryColor) {
3184 this->write(" half4 sk_SecondaryFragColor [[color(0), index(1)]];\n");
3185 }
3186 }
3187 for (const ProgramElement* e : fProgram.elements()) {
3188 if (e->is<GlobalVarDeclaration>()) {
3189 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3190 const Variable& var = *decls.varDeclaration().var();
3191 if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3192 this->write(" uint sk_SampleMask [[sample_mask]];\n");
3193 continue;
3194 }
3195 if (is_output(var)) {
3196 this->write(" ");
3197 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3198 needs_address_space(var.type(), var.modifierFlags())) {
3199 // TODO: address space support
3200 this->write("device ");
3201 }
3202 this->writeType(var.type());
3203 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3204 pass_by_reference(var.type(), var.modifierFlags())) {
3205 this->write("&");
3206 }
3207 this->write(" ");
3208 this->writeName(var.mangledName());
3209
3210 int location = var.layout().fLocation;
3211 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
3212 var.type().typeKind() != Type::TypeKind::kTexture) {
3213 fContext.fErrors->error(var.fPosition,
3214 "Metal out variables must have 'layout(location=...)'");
3215 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3216 // Write attributes for the vertex output that are consistent with what's
3217 // annotated on the fragment input.
3218 this->writeInterpolatedAttributes(var);
3219 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3220 this->write(" [[color(" + std::to_string(location) + ")");
3221 int colorIndex = var.layout().fIndex;
3222 if (colorIndex) {
3223 this->write(", index(" + std::to_string(colorIndex) + ")");
3224 }
3225 this->write("]]");
3226 }
3227 this->write(";\n");
3228 }
3229 }
3230 }
3231 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3232 this->write(" float sk_PointSize [[point_size]];\n");
3233 }
3234 this->write("};\n");
3235 }
3236
writeInterfaceBlocks()3237 void MetalCodeGenerator::writeInterfaceBlocks() {
3238 bool wroteInterfaceBlock = false;
3239 for (const ProgramElement* e : fProgram.elements()) {
3240 if (e->is<InterfaceBlock>()) {
3241 this->writeInterfaceBlock(e->as<InterfaceBlock>());
3242 wroteInterfaceBlock = true;
3243 }
3244 }
3245 if (!wroteInterfaceBlock &&
3246 fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
3247 this->writeLine("struct sksl_synthetic_uniforms {");
3248 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
3249 this->writeLine("};");
3250 }
3251 }
3252
writeStructDefinitions()3253 void MetalCodeGenerator::writeStructDefinitions() {
3254 for (const ProgramElement* e : fProgram.elements()) {
3255 if (e->is<StructDefinition>()) {
3256 this->writeStructDefinition(e->as<StructDefinition>());
3257 }
3258 }
3259 }
3260
writeConstantVariables()3261 void MetalCodeGenerator::writeConstantVariables() {
3262 class : public GlobalStructVisitor {
3263 public:
3264 void visitConstantVariable(const VarDeclaration& decl) override {
3265 fCodeGen->write("constant ");
3266 fCodeGen->writeVarDeclaration(decl);
3267 fCodeGen->finishLine();
3268 }
3269
3270 MetalCodeGenerator* fCodeGen = nullptr;
3271 } visitor;
3272
3273 visitor.fCodeGen = this;
3274 this->visitGlobalStruct(&visitor);
3275 }
3276
visitGlobalStruct(GlobalStructVisitor * visitor)3277 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
3278 for (const ProgramElement* element : fProgram.elements()) {
3279 if (element->is<InterfaceBlock>()) {
3280 const auto* ib = &element->as<InterfaceBlock>();
3281 if (ib->typeName() != "sk_PerVertex") {
3282 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[&ib->var()->type()]);
3283 }
3284 continue;
3285 }
3286 if (!element->is<GlobalVarDeclaration>()) {
3287 continue;
3288 }
3289 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3290 const VarDeclaration& decl = global.varDeclaration();
3291 const Variable& var = *decl.var();
3292 if (decl.baseType().typeKind() == Type::TypeKind::kSampler) {
3293 visitor->visitSampler(var.type(), var.mangledName());
3294 continue;
3295 }
3296 if (decl.baseType().typeKind() == Type::TypeKind::kTexture) {
3297 visitor->visitTexture(var.type(), var.mangledName());
3298 continue;
3299 }
3300 if (!(var.modifierFlags() & ~ModifierFlag::kConst) && var.layout().fBuiltin == -1) {
3301 if (is_in_globals(var)) {
3302 // Visit a regular global variable.
3303 visitor->visitNonconstantVariable(var, decl.value().get());
3304 } else {
3305 // Visit a constant-expression variable.
3306 SkASSERT(var.modifierFlags().isConst());
3307 visitor->visitConstantVariable(decl);
3308 }
3309 }
3310 }
3311 }
3312
writeGlobalStruct()3313 void MetalCodeGenerator::writeGlobalStruct() {
3314 class : public GlobalStructVisitor {
3315 public:
3316 void visitInterfaceBlock(const InterfaceBlock& block,
3317 std::string_view blockName) override {
3318 this->addElement();
3319 fCodeGen->write(" ");
3320 if (is_readonly(block)) {
3321 fCodeGen->write("const ");
3322 }
3323 fCodeGen->write(is_buffer(block) ? "device " : "constant ");
3324 fCodeGen->write(block.typeName());
3325 fCodeGen->write("* ");
3326 fCodeGen->writeName(blockName);
3327 fCodeGen->write(";\n");
3328 }
3329 void visitTexture(const Type& type, std::string_view name) override {
3330 this->addElement();
3331 fCodeGen->write(" ");
3332 fCodeGen->writeType(type);
3333 fCodeGen->write(" ");
3334 fCodeGen->writeName(name);
3335 fCodeGen->write(";\n");
3336 }
3337 void visitSampler(const Type&, std::string_view name) override {
3338 this->addElement();
3339 fCodeGen->write(" sampler2D ");
3340 fCodeGen->writeName(name);
3341 fCodeGen->write(";\n");
3342 }
3343 void visitConstantVariable(const VarDeclaration& decl) override {
3344 // Constants aren't added to the global struct.
3345 }
3346 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3347 this->addElement();
3348 fCodeGen->write(" ");
3349 fCodeGen->writeModifiers(var.modifierFlags());
3350 fCodeGen->writeType(var.type());
3351 fCodeGen->write(" ");
3352 fCodeGen->writeName(var.mangledName());
3353 fCodeGen->write(";\n");
3354 }
3355 void addElement() {
3356 if (fFirst) {
3357 fCodeGen->write("struct Globals {\n");
3358 fFirst = false;
3359 }
3360 }
3361 void finish() {
3362 if (!fFirst) {
3363 fCodeGen->writeLine("};");
3364 fFirst = true;
3365 }
3366 }
3367
3368 MetalCodeGenerator* fCodeGen = nullptr;
3369 bool fFirst = true;
3370 } visitor;
3371
3372 visitor.fCodeGen = this;
3373 this->visitGlobalStruct(&visitor);
3374 visitor.finish();
3375 }
3376
writeGlobalInit()3377 void MetalCodeGenerator::writeGlobalInit() {
3378 class : public GlobalStructVisitor {
3379 public:
3380 void visitInterfaceBlock(const InterfaceBlock& blockType,
3381 std::string_view blockName) override {
3382 this->addElement();
3383 fCodeGen->write("&");
3384 fCodeGen->writeName(blockName);
3385 }
3386 void visitTexture(const Type&, std::string_view name) override {
3387 this->addElement();
3388 fCodeGen->writeName(name);
3389 }
3390 void visitSampler(const Type&, std::string_view name) override {
3391 this->addElement();
3392 fCodeGen->write("{");
3393 fCodeGen->writeName(name);
3394 fCodeGen->write(kTextureSuffix);
3395 fCodeGen->write(", ");
3396 fCodeGen->writeName(name);
3397 fCodeGen->write(kSamplerSuffix);
3398 fCodeGen->write("}");
3399 }
3400 void visitConstantVariable(const VarDeclaration& decl) override {
3401 // Constant-expression variables aren't put in the global struct.
3402 }
3403 void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3404 this->addElement();
3405 if (value) {
3406 fCodeGen->writeVarInitializer(var, *value);
3407 } else {
3408 fCodeGen->write("{}");
3409 }
3410 }
3411 void addElement() {
3412 if (fFirst) {
3413 fCodeGen->write("Globals _globals{");
3414 fFirst = false;
3415 } else {
3416 fCodeGen->write(", ");
3417 }
3418 }
3419 void finish() {
3420 if (!fFirst) {
3421 fCodeGen->writeLine("};");
3422 fCodeGen->writeLine("(void)_globals;");
3423 }
3424 }
3425 MetalCodeGenerator* fCodeGen = nullptr;
3426 bool fFirst = true;
3427 } visitor;
3428
3429 visitor.fCodeGen = this;
3430 this->visitGlobalStruct(&visitor);
3431 visitor.finish();
3432 }
3433
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)3434 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
3435 for (const ProgramElement* element : fProgram.elements()) {
3436 if (!element->is<GlobalVarDeclaration>()) {
3437 continue;
3438 }
3439 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3440 const VarDeclaration& decl = global.varDeclaration();
3441 const Variable& var = *decl.var();
3442 if (var.modifierFlags().isWorkgroup()) {
3443 SkASSERT(!decl.value());
3444 SkASSERT(!var.modifierFlags().isConst());
3445 visitor->visitNonconstantVariable(var);
3446 }
3447 }
3448 }
3449
writeThreadgroupStruct()3450 void MetalCodeGenerator::writeThreadgroupStruct() {
3451 class : public ThreadgroupStructVisitor {
3452 public:
3453 void visitNonconstantVariable(const Variable& var) override {
3454 this->addElement();
3455 fCodeGen->write(" ");
3456 fCodeGen->writeModifiers(var.modifierFlags());
3457 fCodeGen->writeType(var.type());
3458 fCodeGen->write(" ");
3459 fCodeGen->writeName(var.mangledName());
3460 fCodeGen->write(";\n");
3461 }
3462 void addElement() {
3463 if (fFirst) {
3464 fCodeGen->write("struct Threadgroups {\n");
3465 fFirst = false;
3466 }
3467 }
3468 void finish() {
3469 if (!fFirst) {
3470 fCodeGen->writeLine("};");
3471 fFirst = true;
3472 }
3473 }
3474
3475 MetalCodeGenerator* fCodeGen = nullptr;
3476 bool fFirst = true;
3477 } visitor;
3478
3479 visitor.fCodeGen = this;
3480 this->visitThreadgroupStruct(&visitor);
3481 visitor.finish();
3482 }
3483
writeThreadgroupInit()3484 void MetalCodeGenerator::writeThreadgroupInit() {
3485 class : public ThreadgroupStructVisitor {
3486 public:
3487 void visitNonconstantVariable(const Variable& var) override {
3488 this->addElement();
3489 fCodeGen->write("{}");
3490 }
3491 void addElement() {
3492 if (fFirst) {
3493 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3494 fFirst = false;
3495 } else {
3496 fCodeGen->write(", ");
3497 }
3498 }
3499 void finish() {
3500 if (!fFirst) {
3501 fCodeGen->writeLine("};");
3502 fCodeGen->writeLine("(void)_threadgroups;");
3503 }
3504 }
3505 MetalCodeGenerator* fCodeGen = nullptr;
3506 bool fFirst = true;
3507 } visitor;
3508
3509 visitor.fCodeGen = this;
3510 this->visitThreadgroupStruct(&visitor);
3511 visitor.finish();
3512 }
3513
writeProgramElement(const ProgramElement & e)3514 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3515 switch (e.kind()) {
3516 case ProgramElement::Kind::kExtension:
3517 break;
3518 case ProgramElement::Kind::kGlobalVar:
3519 break;
3520 case ProgramElement::Kind::kInterfaceBlock:
3521 // Handled in writeInterfaceBlocks; do nothing.
3522 break;
3523 case ProgramElement::Kind::kStructDefinition:
3524 // Handled in writeStructDefinitions; do nothing.
3525 break;
3526 case ProgramElement::Kind::kFunction:
3527 this->writeFunction(e.as<FunctionDefinition>());
3528 break;
3529 case ProgramElement::Kind::kFunctionPrototype:
3530 this->writeFunctionPrototype(e.as<FunctionPrototype>());
3531 break;
3532 case ProgramElement::Kind::kModifiers:
3533 // Not necessary in Metal; do nothing.
3534 break;
3535 default:
3536 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3537 break;
3538 }
3539 }
3540
requirements(const Statement * s)3541 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3542 class RequirementsVisitor : public ProgramVisitor {
3543 public:
3544 using ProgramVisitor::visitStatement;
3545
3546 bool visitExpression(const Expression& e) override {
3547 switch (e.kind()) {
3548 case Expression::Kind::kFunctionCall: {
3549 const FunctionCall& f = e.as<FunctionCall>();
3550 fRequirements |= fCodeGen->requirements(f.function());
3551 break;
3552 }
3553 case Expression::Kind::kFieldAccess: {
3554 const FieldAccess& f = e.as<FieldAccess>();
3555 if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3556 fRequirements |= kGlobals_Requirement;
3557 return false; // don't recurse into the base variable
3558 }
3559 break;
3560 }
3561 case Expression::Kind::kVariableReference: {
3562 const Variable& var = *e.as<VariableReference>().variable();
3563
3564 if (var.layout().fBuiltin == SK_FRAGCOORD_BUILTIN) {
3565 fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3566 } else if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN) {
3567 fRequirements |= kSampleMaskIn_Requirement;
3568 } else if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3569 fRequirements |= kOutputs_Requirement;
3570 } else if (var.layout().fBuiltin == SK_VERTEXID_BUILTIN) {
3571 fRequirements |= kVertexID_Requirement;
3572 } else if (var.layout().fBuiltin == SK_INSTANCEID_BUILTIN) {
3573 fRequirements |= kInstanceID_Requirement;
3574 } else if (var.storage() == Variable::Storage::kGlobal) {
3575 if (is_input(var)) {
3576 fRequirements |= kInputs_Requirement;
3577 } else if (is_output(var)) {
3578 fRequirements |= kOutputs_Requirement;
3579 } else if (is_uniforms(var)) {
3580 fRequirements |= kUniforms_Requirement;
3581 } else if (is_threadgroup(var)) {
3582 fRequirements |= kThreadgroups_Requirement;
3583 } else if (is_in_globals(var)) {
3584 fRequirements |= kGlobals_Requirement;
3585 }
3586 }
3587 break;
3588 }
3589 default:
3590 break;
3591 }
3592 return ProgramVisitor::visitExpression(e);
3593 }
3594
3595 MetalCodeGenerator* fCodeGen;
3596 Requirements fRequirements = kNo_Requirements;
3597 };
3598
3599 RequirementsVisitor visitor;
3600 if (s) {
3601 visitor.fCodeGen = this;
3602 visitor.visitStatement(*s);
3603 }
3604 return visitor.fRequirements;
3605 }
3606
requirements(const FunctionDeclaration & f)3607 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3608 Requirements* found = fRequirements.find(&f);
3609 if (!found) {
3610 fRequirements.set(&f, kNo_Requirements);
3611 for (const ProgramElement* e : fProgram.elements()) {
3612 if (e->is<FunctionDefinition>()) {
3613 const FunctionDefinition& def = e->as<FunctionDefinition>();
3614 if (&def.declaration() == &f) {
3615 Requirements reqs = this->requirements(def.body().get());
3616 fRequirements.set(&f, reqs);
3617 return reqs;
3618 }
3619 }
3620 }
3621
3622 // We never found a definition for this declared function, but it's legal to prototype a
3623 // function without ever giving a definition, as long as you don't call it.
3624 return kNo_Requirements;
3625 }
3626 return *found;
3627 }
3628
generateCode()3629 bool MetalCodeGenerator::generateCode() {
3630 StringStream header;
3631 {
3632 AutoOutputStream outputToHeader(this, &header, &fIndentation);
3633 this->writeHeader();
3634 this->writeConstantVariables();
3635 this->writeSampler2DPolyfill();
3636 this->writeStructDefinitions();
3637 this->writeUniformStruct();
3638 this->writeInputStruct();
3639 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3640 this->writeOutputStruct();
3641 }
3642 this->writeInterfaceBlocks();
3643 this->writeGlobalStruct();
3644 this->writeThreadgroupStruct();
3645
3646 // Emit prototypes for every built-in function; these aren't always added in perfect order.
3647 for (const ProgramElement* e : fProgram.fSharedElements) {
3648 if (e->is<FunctionDefinition>()) {
3649 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3650 this->writeLine(";");
3651 }
3652 }
3653 }
3654 StringStream body;
3655 {
3656 AutoOutputStream outputToBody(this, &body, &fIndentation);
3657
3658 for (const ProgramElement* e : fProgram.elements()) {
3659 this->writeProgramElement(*e);
3660 }
3661 }
3662 write_stringstream(header, *fOut);
3663 write_stringstream(fExtraFunctionPrototypes, *fOut);
3664 write_stringstream(fExtraFunctions, *fOut);
3665 write_stringstream(body, *fOut);
3666 return fContext.fErrors->errorCount() == 0;
3667 }
3668
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out,PrettyPrint pp)3669 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out, PrettyPrint pp) {
3670 TRACE_EVENT0("skia.shaders", "SkSL::ToMetal");
3671 SkASSERT(caps != nullptr);
3672
3673 program.fContext->fErrors->setSource(*program.fSource);
3674 MetalCodeGenerator cg(program.fContext.get(), caps, &program, &out, pp);
3675 bool result = cg.generateCode();
3676 program.fContext->fErrors->setSource(std::string_view());
3677
3678 return result;
3679 }
3680
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out)3681 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out) {
3682 #if defined(SK_DEBUG)
3683 constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kYes;
3684 #else
3685 constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kNo;
3686 #endif
3687 return ToMetal(program, caps, out, defaultPrintOpts);
3688 }
3689
ToMetal(Program & program,const ShaderCaps * caps,std::string * out)3690 bool ToMetal(Program& program, const ShaderCaps* caps, std::string* out) {
3691 StringStream buffer;
3692 if (!ToMetal(program, caps, buffer)) {
3693 return false;
3694 }
3695 *out = buffer.str();
3696 return true;
3697 }
3698
3699 } // namespace SkSL
3700