xref: /aosp_15_r20/external/skia/src/sksl/codegen/SkSLMetalCodeGenerator.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
8 
9 #include "include/core/SkSpan.h"
10 #include "include/core/SkTypes.h"
11 #include "include/private/base/SkTArray.h"
12 #include "include/private/base/SkTo.h"
13 #include "src/base/SkEnumBitMask.h"
14 #include "src/base/SkScopeExit.h"
15 #include "src/core/SkTHash.h"
16 #include "src/core/SkTraceEvent.h"
17 #include "src/sksl/SkSLAnalysis.h"
18 #include "src/sksl/SkSLBuiltinTypes.h"
19 #include "src/sksl/SkSLCompiler.h"
20 #include "src/sksl/SkSLContext.h"
21 #include "src/sksl/SkSLDefines.h"
22 #include "src/sksl/SkSLErrorReporter.h"
23 #include "src/sksl/SkSLIntrinsicList.h"
24 #include "src/sksl/SkSLMemoryLayout.h"
25 #include "src/sksl/SkSLOperator.h"
26 #include "src/sksl/SkSLOutputStream.h"
27 #include "src/sksl/SkSLPosition.h"
28 #include "src/sksl/SkSLProgramSettings.h"
29 #include "src/sksl/SkSLString.h"
30 #include "src/sksl/SkSLStringStream.h"
31 #include "src/sksl/SkSLUtil.h"
32 #include "src/sksl/analysis/SkSLProgramVisitor.h"
33 #include "src/sksl/codegen/SkSLCodeGenTypes.h"
34 #include "src/sksl/codegen/SkSLCodeGenerator.h"
35 #include "src/sksl/ir/SkSLBinaryExpression.h"
36 #include "src/sksl/ir/SkSLBlock.h"
37 #include "src/sksl/ir/SkSLConstructor.h"
38 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
39 #include "src/sksl/ir/SkSLConstructorCompound.h"
40 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
41 #include "src/sksl/ir/SkSLDoStatement.h"
42 #include "src/sksl/ir/SkSLExpression.h"
43 #include "src/sksl/ir/SkSLExpressionStatement.h"
44 #include "src/sksl/ir/SkSLExtension.h"
45 #include "src/sksl/ir/SkSLFieldAccess.h"
46 #include "src/sksl/ir/SkSLForStatement.h"
47 #include "src/sksl/ir/SkSLFunctionCall.h"
48 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
49 #include "src/sksl/ir/SkSLFunctionDefinition.h"
50 #include "src/sksl/ir/SkSLFunctionPrototype.h"
51 #include "src/sksl/ir/SkSLIRHelpers.h"
52 #include "src/sksl/ir/SkSLIRNode.h"
53 #include "src/sksl/ir/SkSLIfStatement.h"
54 #include "src/sksl/ir/SkSLIndexExpression.h"
55 #include "src/sksl/ir/SkSLInterfaceBlock.h"
56 #include "src/sksl/ir/SkSLLayout.h"
57 #include "src/sksl/ir/SkSLLiteral.h"
58 #include "src/sksl/ir/SkSLModifierFlags.h"
59 #include "src/sksl/ir/SkSLNop.h"
60 #include "src/sksl/ir/SkSLPostfixExpression.h"
61 #include "src/sksl/ir/SkSLPrefixExpression.h"
62 #include "src/sksl/ir/SkSLProgram.h"
63 #include "src/sksl/ir/SkSLProgramElement.h"
64 #include "src/sksl/ir/SkSLReturnStatement.h"
65 #include "src/sksl/ir/SkSLSetting.h"
66 #include "src/sksl/ir/SkSLStatement.h"
67 #include "src/sksl/ir/SkSLStructDefinition.h"
68 #include "src/sksl/ir/SkSLSwitchCase.h"
69 #include "src/sksl/ir/SkSLSwitchStatement.h"
70 #include "src/sksl/ir/SkSLSwizzle.h"
71 #include "src/sksl/ir/SkSLTernaryExpression.h"
72 #include "src/sksl/ir/SkSLType.h"
73 #include "src/sksl/ir/SkSLVarDeclarations.h"
74 #include "src/sksl/ir/SkSLVariable.h"
75 #include "src/sksl/ir/SkSLVariableReference.h"
76 #include "src/sksl/spirv.h"
77 
78 #include <algorithm>
79 #include <cstddef>
80 #include <cstdint>
81 #include <functional>
82 #include <initializer_list>
83 #include <limits>
84 #include <memory>
85 #include <string>
86 #include <string_view>
87 #include <utility>
88 #include <vector>
89 
90 using namespace skia_private;
91 
92 namespace SkSL {
93 
94 class MetalCodeGenerator : public CodeGenerator {
95 public:
MetalCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out,PrettyPrint pp)96     MetalCodeGenerator(const Context* context,
97                        const ShaderCaps* caps,
98                        const Program* program,
99                        OutputStream* out,
100                        PrettyPrint pp)
101             : CodeGenerator(context, caps, program, out)
102             , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
103             , fLineEnding("\n")
104             , fPrettyPrint(pp) {}
105 
106     bool generateCode() override;
107 
108 protected:
109     using Precedence = OperatorPrecedence;
110 
111     using Requirements =  int;
112     static constexpr Requirements kNo_Requirements          = 0;
113     static constexpr Requirements kInputs_Requirement       = 1 << 0;
114     static constexpr Requirements kOutputs_Requirement      = 1 << 1;
115     static constexpr Requirements kUniforms_Requirement     = 1 << 2;
116     static constexpr Requirements kGlobals_Requirement      = 1 << 3;
117     static constexpr Requirements kFragCoord_Requirement    = 1 << 4;
118     static constexpr Requirements kSampleMaskIn_Requirement = 1 << 5;
119     static constexpr Requirements kVertexID_Requirement     = 1 << 6;
120     static constexpr Requirements kInstanceID_Requirement   = 1 << 7;
121     static constexpr Requirements kThreadgroups_Requirement = 1 << 8;
122 
123     class GlobalStructVisitor;
124     void visitGlobalStruct(GlobalStructVisitor* visitor);
125 
126     class ThreadgroupStructVisitor;
127     void visitThreadgroupStruct(ThreadgroupStructVisitor* visitor);
128 
129     void write(std::string_view s);
130 
131     void writeLine(std::string_view s = std::string_view());
132 
133     void finishLine();
134 
135     void writeHeader();
136 
137     void writeSampler2DPolyfill();
138 
139     void writeUniformStruct();
140 
141     void writeInterpolatedAttributes(const Variable& var);
142 
143     void writeInputStruct();
144 
145     void writeOutputStruct();
146 
147     void writeInterfaceBlocks();
148 
149     void writeStructDefinitions();
150 
151     void writeConstantVariables();
152 
153     void writeFields(SkSpan<const Field> fields, Position pos);
154 
155     int size(const Type* type, bool isPacked) const;
156 
157     int alignment(const Type* type, bool isPacked) const;
158 
159     void writeGlobalStruct();
160 
161     void writeGlobalInit();
162 
163     void writeThreadgroupStruct();
164 
165     void writeThreadgroupInit();
166 
167     void writePrecisionModifier();
168 
169     std::string typeName(const Type& type);
170 
171     void writeStructDefinition(const StructDefinition& s);
172 
173     void writeType(const Type& type);
174 
175     void writeExtension(const Extension& ext);
176 
177     void writeInterfaceBlock(const InterfaceBlock& intf);
178 
179     void writeFunctionRequirementParams(const FunctionDeclaration& f,
180                                         const char*& separator);
181 
182     void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);
183 
184     bool writeFunctionDeclaration(const FunctionDeclaration& f);
185 
186     void writeFunction(const FunctionDefinition& f);
187 
188     void writeFunctionPrototype(const FunctionPrototype& f);
189 
190     void writeLayout(const Layout& layout);
191 
192     void writeModifiers(ModifierFlags flags);
193 
194     void writeVarInitializer(const Variable& var, const Expression& value);
195 
196     void writeName(std::string_view name);
197 
198     void writeVarDeclaration(const VarDeclaration& decl);
199 
200     void writeFragCoord();
201 
202     void writeVariableReference(const VariableReference& ref);
203 
204     void writeExpression(const Expression& expr, Precedence parentPrecedence);
205 
206     void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);
207 
208     std::string getInversePolyfill(const ExpressionArray& arguments);
209 
210     std::string getBitcastIntrinsic(const Type& outType);
211 
212     std::string getTempVariable(const Type& varType);
213 
214     void writeFunctionCall(const FunctionCall& c);
215 
216     bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
217     std::string getMatrixConstructHelper(const AnyConstructor& c);
218     void assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows);
219     void assembleMatrixFromExpressions(const AnyConstructor& ctor, int columns, int rows);
220 
221     void writeMatrixCompMult();
222 
223     void writeOuterProduct();
224 
225     void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
226 
227     void writeMatrixDivisionHelpers(const Type& type);
228 
229     void writeMatrixEqualityHelpers(const Type& left, const Type& right);
230 
231     std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);
232 
233     void writeArrayEqualityHelpers(const Type& type);
234 
235     void writeStructEqualityHelpers(const Type& type);
236 
237     void writeEqualityHelpers(const Type& leftType, const Type& rightType);
238 
239     void writeArgumentList(const ExpressionArray& arguments);
240 
241     void writeSimpleIntrinsic(const FunctionCall& c);
242 
243     bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);
244 
245     void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);
246 
247     void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);
248 
249     void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);
250 
251     void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
252                                       Precedence parentPrecedence);
253 
254     void writeAnyConstructor(const AnyConstructor& c,
255                              const char* leftBracket,
256                              const char* rightBracket,
257                              Precedence parentPrecedence);
258 
259     void writeCastConstructor(const AnyConstructor& c,
260                               const char* leftBracket,
261                               const char* rightBracket,
262                               Precedence parentPrecedence);
263 
264     void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);
265 
266     void writeFieldAccess(const FieldAccess& f);
267 
268     void writeSwizzle(const Swizzle& swizzle);
269 
270     // Returns `floatCxR(1.0, 1.0, 1.0, 1.0, ...)`.
271     std::string splatMatrixOf1(const Type& type);
272 
273     // Splats a scalar expression across a matrix of arbitrary size.
274     void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);
275 
276     void writeBinaryExpressionElement(const Expression& expr,
277                                       Operator op,
278                                       const Expression& other,
279                                       Precedence precedence);
280 
281     void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
282 
283     void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
284 
285     void writeIndexExpression(const IndexExpression& expr);
286 
287     void writeIndexInnerExpression(const Expression& expr);
288 
289     void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
290 
291     void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
292 
293     void writeLiteral(const Literal& f);
294 
295     void writeStatement(const Statement& s);
296 
297     void writeStatements(const StatementArray& statements);
298 
299     void writeBlock(const Block& b);
300 
301     void writeIfStatement(const IfStatement& stmt);
302 
303     void writeForStatement(const ForStatement& f);
304 
305     void writeDoStatement(const DoStatement& d);
306 
307     void writeExpressionStatement(const ExpressionStatement& s);
308 
309     void writeSwitchStatement(const SwitchStatement& s);
310 
311     void writeReturnStatementFromMain();
312 
313     void writeReturnStatement(const ReturnStatement& r);
314 
315     void writeProgramElement(const ProgramElement& e);
316 
317     Requirements requirements(const FunctionDeclaration& f);
318 
319     Requirements requirements(const Statement* s);
320 
321     // For compute shader main functions, writes and initializes the _in and _out structs (the
322     // instances, not the types themselves)
323     void writeComputeMainInputs();
324 
325     int getUniformBinding(const Layout& layout);
326 
327     int getUniformSet(const Layout& layout);
328 
329     void writeWithIndexSubstitution(const std::function<void()>& fn);
330 
331     skia_private::THashSet<std::string_view> fReservedWords;
332     skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
333     int fAnonInterfaceCount = 0;
334     int fPaddingCount = 0;
335     const char* fLineEnding;
336     std::string fFunctionHeader;
337     StringStream fExtraFunctions;
338     StringStream fExtraFunctionPrototypes;
339     int fVarCount = 0;
340     int fIndentation = 0;
341     bool fAtLineStart = false;
342     // true if we have run into usages of dFdx / dFdy
343     bool fFoundDerivatives = false;
344     skia_private::THashMap<const FunctionDeclaration*, Requirements> fRequirements;
345     skia_private::THashSet<std::string> fHelpers;
346     int fUniformBuffer = -1;
347     std::string fRTFlipName;
348     const FunctionDeclaration* fCurrentFunction = nullptr;
349     int fSwizzleHelperCount = 0;
350     static constexpr char kTextureSuffix[] = "_Tex";
351     static constexpr char kSamplerSuffix[] = "_Smplr";
352 
353     // If we might use an index expression more than once, we need to capture the result in a
354     // temporary variable to avoid double-evaluation. This should generally only occur when emitting
355     // a function call, since we need to polyfill GLSL-style out-parameter support. (skia:14130)
356     // The map holds <index-expression, temp-variable name>.
357     using IndexSubstitutionMap = skia_private::THashMap<const Expression*, std::string>;
358 
359     // When fIndexSubstitution is null (usually), index-substitution does not need to be performed.
360     struct IndexSubstitutionData {
361         IndexSubstitutionMap fMap;
362         StringStream fMainStream;
363         StringStream fPrefixStream;
364         bool fCreateSubstitutes = true;
365     };
366     std::unique_ptr<IndexSubstitutionData> fIndexSubstitutionData;
367     PrettyPrint fPrettyPrint;
368 
369     // Workaround/polyfill flags
370     bool fWrittenInverse2 = false, fWrittenInverse3 = false, fWrittenInverse4 = false;
371     bool fWrittenMatrixCompMult = false;
372     bool fWrittenOuterProduct = false;
373 };
374 
operator_name(Operator op)375 static const char* operator_name(Operator op) {
376     switch (op.kind()) {
377         case Operator::Kind::LOGICALXOR:  return " != ";
378         default:                          return op.operatorName();
379     }
380 }
381 
382 class MetalCodeGenerator::GlobalStructVisitor {
383 public:
384     virtual ~GlobalStructVisitor() = default;
visitInterfaceBlock(const InterfaceBlock & block,std::string_view blockName)385     virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) {}
visitTexture(const Type & type,std::string_view name)386     virtual void visitTexture(const Type& type, std::string_view name) {}
visitSampler(const Type & type,std::string_view name)387     virtual void visitSampler(const Type& type, std::string_view name) {}
visitConstantVariable(const VarDeclaration & decl)388     virtual void visitConstantVariable(const VarDeclaration& decl) {}
visitNonconstantVariable(const Variable & var,const Expression * value)389     virtual void visitNonconstantVariable(const Variable& var, const Expression* value) {}
390 };
391 
392 class MetalCodeGenerator::ThreadgroupStructVisitor {
393 public:
394     virtual ~ThreadgroupStructVisitor() = default;
395     virtual void visitNonconstantVariable(const Variable& var) = 0;
396 };
397 
write(std::string_view s)398 void MetalCodeGenerator::write(std::string_view s) {
399     if (s.empty()) {
400         return;
401     }
402     if (fAtLineStart && fPrettyPrint == PrettyPrint::kYes) {
403         for (int i = 0; i < fIndentation; i++) {
404             fOut->writeText("    ");
405         }
406     }
407     fOut->writeText(std::string(s).c_str());
408     fAtLineStart = false;
409 }
410 
writeLine(std::string_view s)411 void MetalCodeGenerator::writeLine(std::string_view s) {
412     this->write(s);
413     fOut->writeText(fLineEnding);
414     fAtLineStart = true;
415 }
416 
finishLine()417 void MetalCodeGenerator::finishLine() {
418     if (!fAtLineStart) {
419         this->writeLine();
420     }
421 }
422 
writeExtension(const Extension & ext)423 void MetalCodeGenerator::writeExtension(const Extension& ext) {
424     this->writeLine("#extension " + std::string(ext.name()) + " : enable");
425 }
426 
typeName(const Type & raw)427 std::string MetalCodeGenerator::typeName(const Type& raw) {
428     // we need to know the modifiers for textures
429     const Type& type = raw.resolve().scalarTypeForLiteral();
430     switch (type.typeKind()) {
431         case Type::TypeKind::kArray: {
432             std::string typeName = this->typeName(type.componentType());
433             if (type.isUnsizedArray()) {
434                 return String::printf("const device %s*", typeName.c_str());
435             } else {
436                 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
437                 return String::printf("array<%s, %d>", typeName.c_str(), type.columns());
438             }
439         }
440         case Type::TypeKind::kVector:
441             return this->typeName(type.componentType()) + std::to_string(type.columns());
442 
443         case Type::TypeKind::kMatrix:
444             return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
445                                   std::to_string(type.rows());
446 
447         case Type::TypeKind::kSampler:
448             if (type.dimensions() != SpvDim2D) {
449                 fContext.fErrors->error(Position(), "Unsupported texture dimensions");
450             }
451             return "sampler2D";
452 
453         case Type::TypeKind::kTexture:
454             switch (type.textureAccess()) {
455                 case Type::TextureAccess::kSample:    return "texture2d<half>";
456                 case Type::TextureAccess::kRead:      return "texture2d<half, access::read>";
457                 case Type::TextureAccess::kWrite:     return "texture2d<half, access::write>";
458                 case Type::TextureAccess::kReadWrite: return "texture2d<half, access::read_write>";
459                 default:                              break;
460             }
461             SkUNREACHABLE;
462 
463         case Type::TypeKind::kAtomic:
464             // SkSL currently only supports the atomicUint type.
465             SkASSERT(type.matches(*fContext.fTypes.fAtomicUInt));
466             return "atomic_uint";
467 
468         default:
469             return std::string(type.name());
470     }
471 }
472 
writeStructDefinition(const StructDefinition & s)473 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
474     const Type& type = s.type();
475     this->writeLine("struct " + type.displayName() + " {");
476     fIndentation++;
477     this->writeFields(type.fields(), type.fPosition);
478     fIndentation--;
479     this->writeLine("};");
480 }
481 
writeType(const Type & type)482 void MetalCodeGenerator::writeType(const Type& type) {
483     this->write(this->typeName(type));
484 }
485 
writeExpression(const Expression & expr,Precedence parentPrecedence)486 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
487     switch (expr.kind()) {
488         case Expression::Kind::kBinary:
489             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
490             break;
491         case Expression::Kind::kConstructorArray:
492         case Expression::Kind::kConstructorStruct:
493             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
494             break;
495         case Expression::Kind::kConstructorArrayCast:
496             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
497             break;
498         case Expression::Kind::kConstructorCompound:
499             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
500             break;
501         case Expression::Kind::kConstructorDiagonalMatrix:
502         case Expression::Kind::kConstructorSplat:
503             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
504             break;
505         case Expression::Kind::kConstructorMatrixResize:
506             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
507                                                parentPrecedence);
508             break;
509         case Expression::Kind::kConstructorScalarCast:
510         case Expression::Kind::kConstructorCompoundCast:
511             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
512             break;
513         case Expression::Kind::kEmpty:
514             this->write("false");
515             break;
516         case Expression::Kind::kFieldAccess:
517             this->writeFieldAccess(expr.as<FieldAccess>());
518             break;
519         case Expression::Kind::kLiteral:
520             this->writeLiteral(expr.as<Literal>());
521             break;
522         case Expression::Kind::kFunctionCall:
523             this->writeFunctionCall(expr.as<FunctionCall>());
524             break;
525         case Expression::Kind::kPrefix:
526             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
527             break;
528         case Expression::Kind::kPostfix:
529             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
530             break;
531         case Expression::Kind::kSetting:
532             this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), parentPrecedence);
533             break;
534         case Expression::Kind::kSwizzle:
535             this->writeSwizzle(expr.as<Swizzle>());
536             break;
537         case Expression::Kind::kVariableReference:
538             this->writeVariableReference(expr.as<VariableReference>());
539             break;
540         case Expression::Kind::kTernary:
541             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
542             break;
543         case Expression::Kind::kIndex:
544             this->writeIndexExpression(expr.as<IndexExpression>());
545             break;
546         default:
547             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
548             break;
549     }
550 }
551 
552 // returns true if we should pass by reference instead of by value
pass_by_reference(const Type & type,ModifierFlags flags)553 static bool pass_by_reference(const Type& type, ModifierFlags flags) {
554     return (flags & ModifierFlag::kOut) && !type.isUnsizedArray();
555 }
556 
557 // returns true if we need to specify an address space modifier
needs_address_space(const Type & type,ModifierFlags modifiers)558 static bool needs_address_space(const Type& type, ModifierFlags modifiers) {
559     return type.isUnsizedArray() || pass_by_reference(type, modifiers);
560 }
561 
562 // returns true if the InterfaceBlock has the `buffer` modifier
is_buffer(const InterfaceBlock & block)563 static bool is_buffer(const InterfaceBlock& block) {
564     return block.var()->modifierFlags().isBuffer();
565 }
566 
567 // returns true if the InterfaceBlock has the `readonly` modifier
is_readonly(const InterfaceBlock & block)568 static bool is_readonly(const InterfaceBlock& block) {
569     return block.var()->modifierFlags().isReadOnly();
570 }
571 
getBitcastIntrinsic(const Type & outType)572 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
573     return "as_type<" +  outType.displayName() + ">";
574 }
575 
writeWithIndexSubstitution(const std::function<void ()> & fn)576 void MetalCodeGenerator::writeWithIndexSubstitution(const std::function<void()>& fn) {
577     auto oldIndexSubstitutionData = std::make_unique<IndexSubstitutionData>();
578     fIndexSubstitutionData.swap(oldIndexSubstitutionData);
579 
580     // Invoke our helper function, with output going into our temporary stream.
581     {
582         AutoOutputStream outputToMainStream(this, &fIndexSubstitutionData->fMainStream);
583         fn();
584     }
585 
586     if (fIndexSubstitutionData->fPrefixStream.bytesWritten() == 0) {
587         // Emit the main stream into the program as-is.
588         write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
589     } else {
590         // Emit the prefix stream and main stream into the program as a sequence-expression.
591         // (Each prefix-expression must end with a comma.)
592         this->write("(");
593         write_stringstream(fIndexSubstitutionData->fPrefixStream, *fOut);
594         write_stringstream(fIndexSubstitutionData->fMainStream, *fOut);
595         this->write(")");
596     }
597 
598     fIndexSubstitutionData.swap(oldIndexSubstitutionData);
599 }
600 
writeFunctionCall(const FunctionCall & c)601 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
602     const FunctionDeclaration& function = c.function();
603 
604     // Many intrinsics need to be rewritten in Metal.
605     if (function.isIntrinsic()) {
606         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
607             return;
608         }
609     }
610 
611     // Look for out parameters. SkSL guarantees GLSL's out-param semantics, and we need to emulate
612     // it if an out-param is encountered. (Specifically, out-parameters in GLSL are only written
613     // back to the original variable at the end of the function call; also, swizzles are supported,
614     // whereas Metal doesn't allow a swizzle to be passed to a `floatN&`.)
615     const ExpressionArray& arguments = c.arguments();
616     SkSpan<Variable* const> parameters = function.parameters();
617     SkASSERT(SkToSizeT(arguments.size()) == parameters.size());
618 
619     bool foundOutParam = false;
620     STArray<16, std::string> scratchVarName;
621     scratchVarName.push_back_n(arguments.size(), std::string());
622 
623     for (int index = 0; index < arguments.size(); ++index) {
624         // If this is an out parameter...
625         if (parameters[index]->modifierFlags() & ModifierFlag::kOut) {
626             // Assignability was verified at IRGeneration time, so this should always succeed.
627             [[maybe_unused]] Analysis::AssignmentInfo info;
628             SkASSERT(Analysis::IsAssignable(*arguments[index], &info));
629 
630             scratchVarName[index] = this->getTempVariable(arguments[index]->type());
631             foundOutParam = true;
632         }
633     }
634 
635     if (foundOutParam) {
636         // Out parameters need to be written back to at the end of the function. To do this, we
637         // generate a comma-separated sequence expression that copies the out-param expressions into
638         // our temporary variables, calls the original function--storing its result into a scratch
639         // variable--and then writes the temp variables back into the original out params using the
640         // original out-param expressions. This would look something like:
641         //
642         // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp), _skResult)
643         //       ^                     ^                                     ^                ^
644         //   return value       passes copy of argument    copies back into argument    return value
645         //
646         // While these expressions are complex, they allow us to maintain the proper sequencing that
647         // is necessary for out-parameters, as well as allowing us to support things like swizzles
648         // and array indices which Metal references cannot natively handle.
649 
650         // We will be emitting inout expressions twice, so it's important to enable index
651         // substitution in case we encounter any side-effecting indexes.
652         this->writeWithIndexSubstitution([&] {
653             this->write("((");
654 
655             // ((_skResult =
656             std::string scratchResultName;
657             if (!function.returnType().isVoid()) {
658                 scratchResultName = this->getTempVariable(c.type());
659                 this->write(scratchResultName);
660                 this->write(" = ");
661             }
662 
663             // ((_skResult = func(
664             this->write(function.mangledName());
665             this->write("(");
666 
667             // ((_skResult = func((_skTemp = myOutParam.x), 123
668             const char* separator = "";
669             this->writeFunctionRequirementArgs(function, separator);
670 
671             for (int i = 0; i < arguments.size(); ++i) {
672                 this->write(separator);
673                 separator = ", ";
674                 if (parameters[i]->modifierFlags() & ModifierFlag::kOut) {
675                     SkASSERT(!scratchVarName[i].empty());
676                     if (parameters[i]->modifierFlags() & ModifierFlag::kIn) {
677                         // `inout` parameters initialize the scratch variable with the passed-in
678                         // argument's value.
679                         this->write("(");
680                         this->write(scratchVarName[i]);
681                         this->write(" = ");
682                         this->writeExpression(*arguments[i], Precedence::kAssignment);
683                         this->write(")");
684                     } else {
685                         // `out` parameters pass a reference to the uninitialized scratch variable.
686                         this->write(scratchVarName[i]);
687                     }
688                 } else {
689                     // Regular parameters are passed as-is.
690                     this->writeExpression(*arguments[i], Precedence::kSequence);
691                 }
692             }
693 
694             // ((_skResult = func((_skTemp = myOutParam.x), 123))
695             this->write("))");
696 
697             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp)
698             for (int i = 0; i < arguments.size(); ++i) {
699                 if (!scratchVarName[i].empty()) {
700                     this->write(", (");
701                     this->writeExpression(*arguments[i], Precedence::kAssignment);
702                     this->write(" = ");
703                     this->write(scratchVarName[i]);
704                     this->write(")");
705                 }
706             }
707 
708             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
709             //                                                     _skResult
710             if (!scratchResultName.empty()) {
711                 this->write(", ");
712                 this->write(scratchResultName);
713             }
714 
715             // ((_skResult = func((_skTemp = myOutParam.x), 123)), (myOutParam.x = _skTemp),
716             //                                                     _skResult)
717             this->write(")");
718         });
719     } else {
720         // Emit the function call as-is, only prepending the required arguments.
721         this->write(function.mangledName());
722         this->write("(");
723         const char* separator = "";
724         this->writeFunctionRequirementArgs(function, separator);
725         for (int i = 0; i < arguments.size(); ++i) {
726             SkASSERT(scratchVarName[i].empty());
727             this->write(separator);
728             separator = ", ";
729             this->writeExpression(*arguments[i], Precedence::kSequence);
730         }
731         this->write(")");
732     }
733 }
734 
// Metal source text for `mat2_inverse`, a templated helper that inverts a 2x2 matrix by building
// the adjugate from the matrix's components and scaling it by `1/determinant(m)`. A zero
// determinant is not special-cased. getInversePolyfill() writes this text into fExtraFunctions
// the first time a 2x2 inverse is requested.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
return matrix<T, 2, 2>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));
}
)";
741 
// Metal source text for `mat3_inverse`, a templated helper that inverts a 3x3 matrix. It unpacks
// the nine components (aCR = column C, component R), accumulates the determinant from the
// b01/b11/b21 terms, and scales the resulting matrix by `1/det`. A zero determinant is not
// special-cased. getInversePolyfill() writes this text into fExtraFunctions the first time a 3x3
// inverse is requested.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z,
 b01 =  a22*a11 - a12*a21,
 b11 = -a22*a10 + a12*a20,
 b21 =  a21*a10 - a11*a20,
 det = a00*b01 + a01*b11 + a02*b21;
return matrix<T, 3, 3>(
 b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
 b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
 b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
759 
// Metal source text for `mat4_inverse`, a templated helper that inverts a 4x4 matrix. It unpacks
// the sixteen components (aCR = column C, component R), precomputes the twelve 2x2 sub-determinants
// b00..b11, combines them into `det`, and scales the resulting matrix by `1/det`. A zero
// determinant is not special-cased. getInversePolyfill() writes this text into fExtraFunctions
// the first time a 4x4 inverse is requested.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
T
 a00 = m[0].x, a01 = m[0].y, a02 = m[0].z, a03 = m[0].w,
 a10 = m[1].x, a11 = m[1].y, a12 = m[1].z, a13 = m[1].w,
 a20 = m[2].x, a21 = m[2].y, a22 = m[2].z, a23 = m[2].w,
 a30 = m[3].x, a31 = m[3].y, a32 = m[3].z, a33 = m[3].w,
 b00 = a00*a11 - a01*a10,
 b01 = a00*a12 - a02*a10,
 b02 = a00*a13 - a03*a10,
 b03 = a01*a12 - a02*a11,
 b04 = a01*a13 - a03*a11,
 b05 = a02*a13 - a03*a12,
 b06 = a20*a31 - a21*a30,
 b07 = a20*a32 - a22*a30,
 b08 = a20*a33 - a23*a30,
 b09 = a21*a32 - a22*a31,
 b10 = a21*a33 - a23*a31,
 b11 = a22*a33 - a23*a32,
 det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
return matrix<T, 4, 4>(
 a11*b11 - a12*b10 + a13*b09,
 a02*b10 - a01*b11 - a03*b09,
 a31*b05 - a32*b04 + a33*b03,
 a22*b04 - a21*b05 - a23*b03,
 a12*b08 - a10*b11 - a13*b07,
 a00*b11 - a02*b08 + a03*b07,
 a32*b02 - a30*b05 - a33*b01,
 a20*b05 - a22*b02 + a23*b01,
 a10*b10 - a11*b08 + a13*b06,
 a01*b08 - a00*b10 - a03*b06,
 a30*b04 - a31*b02 + a33*b00,
 a21*b02 - a20*b04 - a23*b00,
 a11*b07 - a10*b09 - a12*b06,
 a00*b09 - a01*b07 + a02*b06,
 a31*b01 - a30*b03 - a32*b00,
 a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
800 
getInversePolyfill(const ExpressionArray & arguments)801 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
802     // Only use polyfills for a function taking a single-argument square matrix.
803     SkASSERT(arguments.size() == 1);
804     const Type& type = arguments.front()->type();
805     if (type.isMatrix() && type.rows() == type.columns()) {
806         switch (type.rows()) {
807             case 2:
808                 if (!fWrittenInverse2) {
809                     fWrittenInverse2 = true;
810                     fExtraFunctions.writeText(kInverse2x2);
811                 }
812                 return "mat2_inverse";
813             case 3:
814                 if (!fWrittenInverse3) {
815                     fWrittenInverse3 = true;
816                     fExtraFunctions.writeText(kInverse3x3);
817                 }
818                 return "mat3_inverse";
819             case 4:
820                 if (!fWrittenInverse4) {
821                     fWrittenInverse4 = true;
822                     fExtraFunctions.writeText(kInverse4x4);
823                 }
824                 return "mat4_inverse";
825         }
826     }
827     SkDEBUGFAILF("no polyfill for inverse(%s)", type.description().c_str());
828     return "inverse";
829 }
830 
writeMatrixCompMult()831 void MetalCodeGenerator::writeMatrixCompMult() {
832     static constexpr char kMatrixCompMult[] = R"(
833 template <typename T, int C, int R>
834 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
835  for (int c = 0; c < C; ++c) { a[c] *= b[c]; }
836  return a;
837 }
838 )";
839     if (!fWrittenMatrixCompMult) {
840         fWrittenMatrixCompMult = true;
841         fExtraFunctions.writeText(kMatrixCompMult);
842     }
843 }
844 
writeOuterProduct()845 void MetalCodeGenerator::writeOuterProduct() {
846     static constexpr char kOuterProduct[] = R"(
847 template <typename T, int C, int R>
848 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
849  matrix<T, C, R> m;
850  for (int c = 0; c < C; ++c) { m[c] = a * b[c]; }
851  return m;
852 }
853 )";
854     if (!fWrittenOuterProduct) {
855         fWrittenOuterProduct = true;
856         fExtraFunctions.writeText(kOuterProduct);
857     }
858 }
859 
getTempVariable(const Type & type)860 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
861     std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
862     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
863     return tempVar;
864 }
865 
writeSimpleIntrinsic(const FunctionCall & c)866 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
867     // Write out an intrinsic function call exactly as-is. No muss no fuss.
868     this->write(c.function().name());
869     this->writeArgumentList(c.arguments());
870 }
871 
writeArgumentList(const ExpressionArray & arguments)872 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
873     this->write("(");
874     const char* separator = "";
875     for (const std::unique_ptr<Expression>& arg : arguments) {
876         this->write(separator);
877         separator = ", ";
878         this->writeExpression(*arg, Precedence::kSequence);
879     }
880     this->write(")");
881 }
882 
writeIntrinsicCall(const FunctionCall & c,IntrinsicKind kind)883 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
884     const ExpressionArray& arguments = c.arguments();
885     switch (kind) {
886         case k_textureRead_IntrinsicKind: {
887             this->writeExpression(*arguments[0], Precedence::kExpression);
888             this->write(".read(");
889             this->writeExpression(*arguments[1], Precedence::kSequence);
890             this->write(")");
891             return true;
892         }
893         case k_textureWrite_IntrinsicKind: {
894             this->writeExpression(*arguments[0], Precedence::kExpression);
895             this->write(".write(");
896             this->writeExpression(*arguments[2], Precedence::kSequence);
897             this->write(", ");
898             this->writeExpression(*arguments[1], Precedence::kSequence);
899             this->write(")");
900             return true;
901         }
902         case k_textureWidth_IntrinsicKind: {
903             this->writeExpression(*arguments[0], Precedence::kExpression);
904             this->write(".get_width()");
905             return true;
906         }
907         case k_textureHeight_IntrinsicKind: {
908             this->writeExpression(*arguments[0], Precedence::kExpression);
909             this->write(".get_height()");
910             return true;
911         }
912         case k_mod_IntrinsicKind: {
913             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
914             std::string tmpX = this->getTempVariable(arguments[0]->type());
915             std::string tmpY = this->getTempVariable(arguments[1]->type());
916             this->write("(" + tmpX + " = ");
917             this->writeExpression(*arguments[0], Precedence::kSequence);
918             this->write(", " + tmpY + " = ");
919             this->writeExpression(*arguments[1], Precedence::kSequence);
920             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
921             return true;
922         }
923         case k_pow_IntrinsicKind: {
924             // The Metal equivalent to GLSL's pow() is powr(). Metal's pow() allows negative base
925             // values, which is presumably more expensive to compute.
926             this->write("powr(");
927             this->writeExpression(*arguments[0], Precedence::kSequence);
928             this->write(", ");
929             this->writeExpression(*arguments[1], Precedence::kSequence);
930             this->write(")");
931             return true;
932         }
933         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
934         case k_distance_IntrinsicKind: {
935             if (arguments[0]->type().columns() == 1) {
936                 this->write("abs(");
937                 this->writeExpression(*arguments[0], Precedence::kAdditive);
938                 this->write(" - ");
939                 this->writeExpression(*arguments[1], Precedence::kAdditive);
940                 this->write(")");
941             } else {
942                 this->writeSimpleIntrinsic(c);
943             }
944             return true;
945         }
946         case k_dot_IntrinsicKind: {
947             if (arguments[0]->type().columns() == 1) {
948                 this->write("(");
949                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
950                 this->write(" * ");
951                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
952                 this->write(")");
953             } else {
954                 this->writeSimpleIntrinsic(c);
955             }
956             return true;
957         }
958         case k_faceforward_IntrinsicKind: {
959             if (arguments[0]->type().columns() == 1) {
960                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
961                 this->write("((((");
962                 this->writeExpression(*arguments[2], Precedence::kSequence);
963                 this->write(") * (");
964                 this->writeExpression(*arguments[1], Precedence::kSequence);
965                 this->write(") < 0) ? 1 : -1) * (");
966                 this->writeExpression(*arguments[0], Precedence::kSequence);
967                 this->write("))");
968             } else {
969                 this->writeSimpleIntrinsic(c);
970             }
971             return true;
972         }
973         case k_length_IntrinsicKind: {
974             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
975             this->writeExpression(*arguments[0], Precedence::kSequence);
976             this->write(")");
977             return true;
978         }
979         case k_normalize_IntrinsicKind: {
980             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
981             this->writeExpression(*arguments[0], Precedence::kSequence);
982             this->write(")");
983             return true;
984         }
985         case k_packUnorm2x16_IntrinsicKind: {
986             this->write("pack_float_to_unorm2x16(");
987             this->writeExpression(*arguments[0], Precedence::kSequence);
988             this->write(")");
989             return true;
990         }
991         case k_unpackUnorm2x16_IntrinsicKind: {
992             this->write("unpack_unorm2x16_to_float(");
993             this->writeExpression(*arguments[0], Precedence::kSequence);
994             this->write(")");
995             return true;
996         }
997         case k_packSnorm2x16_IntrinsicKind: {
998             this->write("pack_float_to_snorm2x16(");
999             this->writeExpression(*arguments[0], Precedence::kSequence);
1000             this->write(")");
1001             return true;
1002         }
1003         case k_unpackSnorm2x16_IntrinsicKind: {
1004             this->write("unpack_snorm2x16_to_float(");
1005             this->writeExpression(*arguments[0], Precedence::kSequence);
1006             this->write(")");
1007             return true;
1008         }
1009         case k_packUnorm4x8_IntrinsicKind: {
1010             this->write("pack_float_to_unorm4x8(");
1011             this->writeExpression(*arguments[0], Precedence::kSequence);
1012             this->write(")");
1013             return true;
1014         }
1015         case k_unpackUnorm4x8_IntrinsicKind: {
1016             this->write("unpack_unorm4x8_to_float(");
1017             this->writeExpression(*arguments[0], Precedence::kSequence);
1018             this->write(")");
1019             return true;
1020         }
1021         case k_packSnorm4x8_IntrinsicKind: {
1022             this->write("pack_float_to_snorm4x8(");
1023             this->writeExpression(*arguments[0], Precedence::kSequence);
1024             this->write(")");
1025             return true;
1026         }
1027         case k_unpackSnorm4x8_IntrinsicKind: {
1028             this->write("unpack_snorm4x8_to_float(");
1029             this->writeExpression(*arguments[0], Precedence::kSequence);
1030             this->write(")");
1031             return true;
1032         }
1033         case k_packHalf2x16_IntrinsicKind: {
1034             this->write("as_type<uint>(half2(");
1035             this->writeExpression(*arguments[0], Precedence::kSequence);
1036             this->write("))");
1037             return true;
1038         }
1039         case k_unpackHalf2x16_IntrinsicKind: {
1040             this->write("float2(as_type<half2>(");
1041             this->writeExpression(*arguments[0], Precedence::kSequence);
1042             this->write("))");
1043             return true;
1044         }
1045         case k_floatBitsToInt_IntrinsicKind:
1046         case k_floatBitsToUint_IntrinsicKind:
1047         case k_intBitsToFloat_IntrinsicKind:
1048         case k_uintBitsToFloat_IntrinsicKind: {
1049             this->write(this->getBitcastIntrinsic(c.type()));
1050             this->write("(");
1051             this->writeExpression(*arguments[0], Precedence::kSequence);
1052             this->write(")");
1053             return true;
1054         }
1055         case k_degrees_IntrinsicKind: {
1056             this->write("((");
1057             this->writeExpression(*arguments[0], Precedence::kSequence);
1058             this->write(") * 57.2957795)");
1059             return true;
1060         }
1061         case k_radians_IntrinsicKind: {
1062             this->write("((");
1063             this->writeExpression(*arguments[0], Precedence::kSequence);
1064             this->write(") * 0.0174532925)");
1065             return true;
1066         }
1067         case k_dFdx_IntrinsicKind: {
1068             this->write("dfdx");
1069             this->writeArgumentList(c.arguments());
1070             return true;
1071         }
1072         case k_dFdy_IntrinsicKind: {
1073             if (!fRTFlipName.empty()) {
1074                 this->write("(" + fRTFlipName + ".y * dfdy");
1075             } else {
1076                 this->write("(dfdy");
1077             }
1078             this->writeArgumentList(c.arguments());
1079             this->write(")");
1080             return true;
1081         }
1082         case k_inverse_IntrinsicKind: {
1083             this->write(this->getInversePolyfill(arguments));
1084             this->writeArgumentList(c.arguments());
1085             return true;
1086         }
1087         case k_inversesqrt_IntrinsicKind: {
1088             this->write("rsqrt");
1089             this->writeArgumentList(c.arguments());
1090             return true;
1091         }
1092         case k_atan_IntrinsicKind: {
1093             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
1094             this->writeArgumentList(c.arguments());
1095             return true;
1096         }
1097         case k_reflect_IntrinsicKind: {
1098             if (arguments[0]->type().columns() == 1) {
1099                 // We need to synthesize `I - 2 * N * I * N`.
1100                 std::string tmpI = this->getTempVariable(arguments[0]->type());
1101                 std::string tmpN = this->getTempVariable(arguments[1]->type());
1102 
1103                 // (_skTempI = ...
1104                 this->write("(" + tmpI + " = ");
1105                 this->writeExpression(*arguments[0], Precedence::kSequence);
1106 
1107                 // , _skTempN = ...
1108                 this->write(", " + tmpN + " = ");
1109                 this->writeExpression(*arguments[1], Precedence::kSequence);
1110 
1111                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
1112                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
1113             } else {
1114                 this->writeSimpleIntrinsic(c);
1115             }
1116             return true;
1117         }
1118         case k_refract_IntrinsicKind: {
1119             if (arguments[0]->type().columns() == 1) {
1120                 // Metal does implement refract for vectors; rather than reimplementing refract from
1121                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
1122                 this->write("(refract(float2(");
1123                 this->writeExpression(*arguments[0], Precedence::kSequence);
1124                 this->write(", 0), float2(");
1125                 this->writeExpression(*arguments[1], Precedence::kSequence);
1126                 this->write(", 0), ");
1127                 this->writeExpression(*arguments[2], Precedence::kSequence);
1128                 this->write(").x)");
1129             } else {
1130                 this->writeSimpleIntrinsic(c);
1131             }
1132             return true;
1133         }
1134         case k_roundEven_IntrinsicKind: {
1135             this->write("rint");
1136             this->writeArgumentList(c.arguments());
1137             return true;
1138         }
1139         case k_bitCount_IntrinsicKind: {
1140             this->write("popcount(");
1141             this->writeExpression(*arguments[0], Precedence::kSequence);
1142             this->write(")");
1143             return true;
1144         }
1145         case k_findLSB_IntrinsicKind: {
1146             // Create a temp variable to store the expression, to avoid double-evaluating it.
1147             std::string skTemp = this->getTempVariable(arguments[0]->type());
1148             std::string exprType = this->typeName(arguments[0]->type());
1149 
1150             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
1151             // Use select to detect zero inputs and force a -1 result.
1152 
1153             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
1154             this->write("(");
1155             this->write(skTemp);
1156             this->write(" = (");
1157             this->writeExpression(*arguments[0], Precedence::kSequence);
1158             this->write("), select(ctz(");
1159             this->write(skTemp);
1160             this->write("), ");
1161             this->write(exprType);
1162             this->write("(-1), ");
1163             this->write(skTemp);
1164             this->write(" == ");
1165             this->write(exprType);
1166             this->write("(0)))");
1167             return true;
1168         }
1169         case k_findMSB_IntrinsicKind: {
1170             // Create a temp variable to store the expression, to avoid double-evaluating it.
1171             std::string skTemp1 = this->getTempVariable(arguments[0]->type());
1172             std::string exprType = this->typeName(arguments[0]->type());
1173 
1174             // GLSL findMSB is actually quite different from Metal's clz:
1175             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
1176             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
1177 
1178             // (_skTemp1 = (.....),
1179             this->write("(");
1180             this->write(skTemp1);
1181             this->write(" = (");
1182             this->writeExpression(*arguments[0], Precedence::kSequence);
1183             this->write("), ");
1184 
1185             // Signed input types might be negative; we need another helper variable to negate the
1186             // input (since we can only find one bits, not zero bits).
1187             std::string skTemp2;
1188             if (arguments[0]->type().isSigned()) {
1189                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
1190                 skTemp2 = this->getTempVariable(arguments[0]->type());
1191                 this->write(skTemp2);
1192                 this->write(" = (select(");
1193                 this->write(skTemp1);
1194                 this->write(", ~");
1195                 this->write(skTemp1);
1196                 this->write(", ");
1197                 this->write(skTemp1);
1198                 this->write(" < 0)), ");
1199             } else {
1200                 skTemp2 = skTemp1;
1201             }
1202 
1203             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
1204             this->write("select(");
1205             this->write(this->typeName(c.type()));
1206             this->write("(clz(");
1207             this->write(skTemp2);
1208             this->write(")), ");
1209             this->write(this->typeName(c.type()));
1210             this->write("(-1), ");
1211             this->write(skTemp2);
1212             this->write(" == ");
1213             this->write(exprType);
1214             this->write("(0)))");
1215             return true;
1216         }
1217         case k_sign_IntrinsicKind: {
1218             if (arguments[0]->type().componentType().isInteger()) {
1219                 // Create a temp variable to store the expression, to avoid double-evaluating it.
1220                 std::string skTemp = this->getTempVariable(arguments[0]->type());
1221                 std::string exprType = this->typeName(arguments[0]->type());
1222 
1223                 // (_skTemp = (.....),
1224                 this->write("(");
1225                 this->write(skTemp);
1226                 this->write(" = (");
1227                 this->writeExpression(*arguments[0], Precedence::kSequence);
1228                 this->write("), ");
1229 
1230                 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
1231                 this->write("select(select(");
1232                 this->write(exprType);
1233                 this->write("(0), ");
1234                 this->write(exprType);
1235                 this->write("(-1), ");
1236                 this->write(skTemp);
1237                 this->write(" < 0), ");
1238                 this->write(exprType);
1239                 this->write("(1), ");
1240                 this->write(skTemp);
1241                 this->write(" > 0))");
1242             } else {
1243                 this->writeSimpleIntrinsic(c);
1244             }
1245             return true;
1246         }
1247         case k_matrixCompMult_IntrinsicKind: {
1248             this->writeMatrixCompMult();
1249             this->writeSimpleIntrinsic(c);
1250             return true;
1251         }
1252         case k_outerProduct_IntrinsicKind: {
1253             this->writeOuterProduct();
1254             this->writeSimpleIntrinsic(c);
1255             return true;
1256         }
1257         case k_mix_IntrinsicKind: {
1258             SkASSERT(c.arguments().size() == 3);
1259             if (arguments[2]->type().componentType().isBoolean()) {
1260                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
1261                 this->write("select");
1262                 this->writeArgumentList(c.arguments());
1263                 return true;
1264             }
1265             // The basic form of mix() is supported by Metal as-is.
1266             this->writeSimpleIntrinsic(c);
1267             return true;
1268         }
1269         case k_equal_IntrinsicKind:
1270         case k_greaterThan_IntrinsicKind:
1271         case k_greaterThanEqual_IntrinsicKind:
1272         case k_lessThan_IntrinsicKind:
1273         case k_lessThanEqual_IntrinsicKind:
1274         case k_notEqual_IntrinsicKind: {
1275             this->write("(");
1276             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
1277             switch (kind) {
1278                 case k_equal_IntrinsicKind:
1279                     this->write(" == ");
1280                     break;
1281                 case k_notEqual_IntrinsicKind:
1282                     this->write(" != ");
1283                     break;
1284                 case k_lessThan_IntrinsicKind:
1285                     this->write(" < ");
1286                     break;
1287                 case k_lessThanEqual_IntrinsicKind:
1288                     this->write(" <= ");
1289                     break;
1290                 case k_greaterThan_IntrinsicKind:
1291                     this->write(" > ");
1292                     break;
1293                 case k_greaterThanEqual_IntrinsicKind:
1294                     this->write(" >= ");
1295                     break;
1296                 default:
1297                     SK_ABORT("unsupported comparison intrinsic kind");
1298             }
1299             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
1300             this->write(")");
1301             return true;
1302         }
1303         case k_storageBarrier_IntrinsicKind:
1304             this->write("threadgroup_barrier(mem_flags::mem_device)");
1305             return true;
1306         case k_workgroupBarrier_IntrinsicKind:
1307             this->write("threadgroup_barrier(mem_flags::mem_threadgroup)");
1308             return true;
1309         case k_atomicAdd_IntrinsicKind:
1310             this->write("atomic_fetch_add_explicit(&");
1311             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1312             this->write(", ");
1313             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1314             this->write(", memory_order_relaxed)");
1315             return true;
1316         case k_atomicLoad_IntrinsicKind:
1317             this->write("atomic_load_explicit(&");
1318             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1319             this->write(", memory_order_relaxed)");
1320             return true;
1321         case k_atomicStore_IntrinsicKind:
1322             this->write("atomic_store_explicit(&");
1323             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
1324             this->write(", ");
1325             this->writeExpression(*c.arguments()[1], Precedence::kSequence);
1326             this->write(", memory_order_relaxed)");
1327             return true;
1328         default:
1329             return false;
1330     }
1331 }
1332 
1333 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
1334 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
assembleMatrixFromMatrix(const Type & sourceMatrix,int columns,int rows)1335 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int columns, int rows) {
1336     SkASSERT(rows <= 4);
1337     SkASSERT(columns <= 4);
1338 
1339     std::string matrixType = this->typeName(sourceMatrix.componentType());
1340 
1341     const char* separator = "";
1342     for (int c = 0; c < columns; ++c) {
1343         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1344         separator = "), ";
1345 
1346         // Determine how many values to take from the source matrix for this row.
1347         int swizzleLength = 0;
1348         if (c < sourceMatrix.columns()) {
1349             swizzleLength = std::min<>(rows, sourceMatrix.rows());
1350         }
1351 
1352         // Emit all the values from the source matrix row.
1353         bool firstItem;
1354         switch (swizzleLength) {
1355             case 0:  firstItem = true;                                            break;
1356             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
1357             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
1358             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
1359             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
1360             default: SkUNREACHABLE;
1361         }
1362 
1363         // Emit the placeholder identity-matrix cells.
1364         for (int r = swizzleLength; r < rows; ++r) {
1365             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1366             firstItem = false;
1367         }
1368     }
1369 
1370     fExtraFunctions.writeText(")");
1371 }
1372 
1373 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1374 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
assembleMatrixFromExpressions(const AnyConstructor & ctor,int columns,int rows)1375 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1376                                                        int columns,
1377                                                        int rows) {
1378     SkASSERT(rows <= 4);
1379     SkASSERT(columns <= 4);
1380 
1381     std::string matrixType = this->typeName(ctor.type().componentType());
1382     size_t argIndex = 0;
1383     int argPosition = 0;
1384     auto args = ctor.argumentSpan();
1385 
1386     static constexpr char kSwizzle[] = "xyzw";
1387     const char* separator = "";
1388     for (int c = 0; c < columns; ++c) {
1389         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1390         separator = "), ";
1391 
1392         const char* columnSeparator = "";
1393         for (int r = 0; r < rows;) {
1394             fExtraFunctions.writeText(columnSeparator);
1395             columnSeparator = ", ";
1396 
1397             if (argIndex < args.size()) {
1398                 const Type& argType = args[argIndex]->type();
1399                 switch (argType.typeKind()) {
1400                     case Type::TypeKind::kScalar: {
1401                         fExtraFunctions.printf("x%zu", argIndex);
1402                         ++r;
1403                         ++argPosition;
1404                         break;
1405                     }
1406                     case Type::TypeKind::kVector: {
1407                         fExtraFunctions.printf("x%zu.", argIndex);
1408                         do {
1409                             fExtraFunctions.write8(kSwizzle[argPosition]);
1410                             ++r;
1411                             ++argPosition;
1412                         } while (r < rows && argPosition < argType.columns());
1413                         break;
1414                     }
1415                     case Type::TypeKind::kMatrix: {
1416                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1417                         do {
1418                             fExtraFunctions.write8(kSwizzle[argPosition]);
1419                             ++r;
1420                             ++argPosition;
1421                         } while (r < rows && (argPosition % argType.rows()) != 0);
1422                         break;
1423                     }
1424                     default: {
1425                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1426                         fExtraFunctions.writeText("<error>");
1427                         break;
1428                     }
1429                 }
1430 
1431                 if (argPosition >= argType.columns() * argType.rows()) {
1432                     ++argIndex;
1433                     argPosition = 0;
1434                 }
1435             } else {
1436                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1437                 fExtraFunctions.writeText("<error>");
1438             }
1439         }
1440     }
1441 
1442     if (argPosition != 0 || argIndex != args.size()) {
1443         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1444         fExtraFunctions.writeText(", <error>");
1445     }
1446 
1447     fExtraFunctions.writeText(")");
1448 }
1449 
1450 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1451 // Keeps track of previously generated constructors so that we won't generate more than one
1452 // constructor for any given permutation of input argument types. Returns the name of the
1453 // generated constructor method.
getMatrixConstructHelper(const AnyConstructor & c)1454 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1455     const Type& type = c.type();
1456     int columns = type.columns();
1457     int rows = type.rows();
1458     auto args = c.argumentSpan();
1459     std::string typeName = this->typeName(type);
1460 
1461     // Create the helper-method name and use it as our lookup key.
1462     std::string name = String::printf("%s_from", typeName.c_str());
1463     for (const std::unique_ptr<Expression>& expr : args) {
1464         String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1465     }
1466 
1467     // If a helper-method has not been synthesized yet, create it now.
1468     if (!fHelpers.contains(name)) {
1469         fHelpers.add(name);
1470 
1471         // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1472         // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1473         // supply a mixture of scalars and vectors.)
1474         fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1475 
1476         size_t argIndex = 0;
1477         const char* argSeparator = "";
1478         for (const std::unique_ptr<Expression>& expr : args) {
1479             fExtraFunctions.printf("%s%s x%zu", argSeparator,
1480                                    this->typeName(expr->type()).c_str(), argIndex++);
1481             argSeparator = ", ";
1482         }
1483 
1484         fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1485 
1486         if (args.size() == 1 && args.front()->type().isMatrix()) {
1487             this->assembleMatrixFromMatrix(args.front()->type(), columns, rows);
1488         } else {
1489             this->assembleMatrixFromExpressions(c, columns, rows);
1490         }
1491 
1492         fExtraFunctions.writeText(");\n}\n");
1493     }
1494     return name;
1495 }
1496 
matrixConstructHelperIsNeeded(const ConstructorCompound & c)1497 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1498     SkASSERT(c.type().isMatrix());
1499 
1500     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1501     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1502     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1503     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1504     // that converts our inputs into a properly-shaped matrix.
1505     // A matrix construct helper method is always used if any input argument is a matrix.
1506     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1507     //
1508     // float2 x = (1, 2);
1509     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1510     //                           | 2 4 6 |
1511     //
1512     // float2 x = (2, 3);
1513     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1514     //                           | 2 4 6 |
1515     //
1516     // float4 x = (1, 2, 3, 4);
1517     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1518     //               | 2 4 |
1519     //
1520 
1521     int position = 0;
1522     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1523         // If an input argument is a matrix, we need a helper function.
1524         if (expr->type().isMatrix()) {
1525             return true;
1526         }
1527         position += expr->type().columns();
1528         if (position > c.type().rows()) {
1529             // An input argument would span multiple rows; a helper function is required.
1530             return true;
1531         }
1532         if (position == c.type().rows()) {
1533             // We've advanced to the end of a row. Wrap to the start of the next row.
1534             position = 0;
1535         }
1536     }
1537 
1538     return false;
1539 }
1540 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,Precedence parentPrecedence)1541 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1542                                                       Precedence parentPrecedence) {
1543     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1544     // matrix-construct helper here.
1545     this->write(this->getMatrixConstructHelper(c));
1546     this->write("(");
1547     this->writeExpression(*c.argument(), Precedence::kSequence);
1548     this->write(")");
1549 }
1550 
writeConstructorCompound(const ConstructorCompound & c,Precedence parentPrecedence)1551 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1552                                                   Precedence parentPrecedence) {
1553     if (c.type().isVector()) {
1554         this->writeConstructorCompoundVector(c, parentPrecedence);
1555     } else if (c.type().isMatrix()) {
1556         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1557     } else {
1558         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1559     }
1560 }
1561 
writeConstructorArrayCast(const ConstructorArrayCast & c,Precedence parentPrecedence)1562 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1563                                                    Precedence parentPrecedence) {
1564     const Type& inType = c.argument()->type().componentType();
1565     const Type& outType = c.type().componentType();
1566     std::string inTypeName = this->typeName(inType);
1567     std::string outTypeName = this->typeName(outType);
1568 
1569     std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1570     if (!fHelpers.contains(name)) {
1571         fHelpers.add(name);
1572         fExtraFunctions.printf(R"(
1573 template <size_t N>
1574 array<%s, N> %s(thread const array<%s, N>& x) {
1575     array<%s, N> result;
1576     for (int i = 0; i < N; ++i) {
1577         result[i] = %s(x[i]);
1578     }
1579     return result;
1580 }
1581 )",
1582                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1583                                outTypeName.c_str(),
1584                                outTypeName.c_str());
1585     }
1586 
1587     this->write(name);
1588     this->write("(");
1589     this->writeExpression(*c.argument(), Precedence::kSequence);
1590     this->write(")");
1591 }
1592 
getVectorFromMat2x2ConstructorHelper(const Type & matrixType)1593 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1594     SkASSERT(matrixType.isMatrix());
1595     SkASSERT(matrixType.rows() == 2);
1596     SkASSERT(matrixType.columns() == 2);
1597 
1598     std::string baseType = this->typeName(matrixType.componentType());
1599     std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1600     if (!fHelpers.contains(name)) {
1601         fHelpers.add(name);
1602 
1603         fExtraFunctions.printf(R"(
1604 %s4 %s(%s2x2 x) {
1605     return %s4(x[0].xy, x[1].xy);
1606 }
1607 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1608     }
1609 
1610     return name;
1611 }
1612 
writeConstructorCompoundVector(const ConstructorCompound & c,Precedence parentPrecedence)1613 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1614                                                         Precedence parentPrecedence) {
1615     SkASSERT(c.type().isVector());
1616 
1617     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1618     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1619     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1620         const Expression& expr = *c.argumentSpan().front();
1621         if (expr.type().isMatrix()) {
1622             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1623             this->write("(");
1624             this->writeExpression(expr, Precedence::kSequence);
1625             this->write(")");
1626             return;
1627         }
1628     }
1629 
1630     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1631 }
1632 
writeConstructorCompoundMatrix(const ConstructorCompound & c,Precedence parentPrecedence)1633 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1634                                                         Precedence parentPrecedence) {
1635     SkASSERT(c.type().isMatrix());
1636 
1637     // Emit and invoke a matrix-constructor helper method if one is necessary.
1638     if (this->matrixConstructHelperIsNeeded(c)) {
1639         this->write(this->getMatrixConstructHelper(c));
1640         this->write("(");
1641         const char* separator = "";
1642         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1643             this->write(separator);
1644             separator = ", ";
1645             this->writeExpression(*expr, Precedence::kSequence);
1646         }
1647         this->write(")");
1648         return;
1649     }
1650 
1651     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1652     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1653     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1654     // can group our inputs up and synthesize a constructor for each column.
1655     const Type& matrixType = c.type();
1656     const Type& columnType = matrixType.columnType(fContext);
1657 
1658     this->writeType(matrixType);
1659     this->write("(");
1660     const char* separator = "";
1661     int scalarCount = 0;
1662     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1663         this->write(separator);
1664         separator = ", ";
1665         if (arg->type().columns() < matrixType.rows()) {
1666             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1667             if (!scalarCount) {
1668                 this->writeType(columnType);
1669                 this->write("(");
1670             }
1671             scalarCount += arg->type().columns();
1672         }
1673         this->writeExpression(*arg, Precedence::kSequence);
1674         if (scalarCount && scalarCount == matrixType.rows()) {
1675             // Close our `floatN(...` constructor block from above.
1676             this->write(")");
1677             scalarCount = 0;
1678         }
1679     }
1680     this->write(")");
1681 }
1682 
writeAnyConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1683 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1684                                              const char* leftBracket,
1685                                              const char* rightBracket,
1686                                              Precedence parentPrecedence) {
1687     this->writeType(c.type());
1688     this->write(leftBracket);
1689     const char* separator = "";
1690     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1691         this->write(separator);
1692         separator = ", ";
1693         this->writeExpression(*arg, Precedence::kSequence);
1694     }
1695     this->write(rightBracket);
1696 }
1697 
writeCastConstructor(const AnyConstructor & c,const char * leftBracket,const char * rightBracket,Precedence parentPrecedence)1698 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1699                                               const char* leftBracket,
1700                                               const char* rightBracket,
1701                                               Precedence parentPrecedence) {
1702     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1703 }
1704 
writeFragCoord()1705 void MetalCodeGenerator::writeFragCoord() {
1706     if (!fRTFlipName.empty()) {
1707         this->write("float4(_fragCoord.x, ");
1708         this->write(fRTFlipName.c_str());
1709         this->write(".x + ");
1710         this->write(fRTFlipName.c_str());
1711         this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1712     } else {
1713         this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1714     }
1715 }
1716 
is_compute_builtin(const Variable & var)1717 static bool is_compute_builtin(const Variable& var) {
1718     switch (var.layout().fBuiltin) {
1719         case SK_NUMWORKGROUPS_BUILTIN:
1720         case SK_WORKGROUPID_BUILTIN:
1721         case SK_LOCALINVOCATIONID_BUILTIN:
1722         case SK_GLOBALINVOCATIONID_BUILTIN:
1723         case SK_LOCALINVOCATIONINDEX_BUILTIN:
1724             return true;
1725         default:
1726             break;
1727     }
1728     return false;
1729 }
1730 
1731 // true if the var is part of the Inputs struct
is_input(const Variable & var)1732 static bool is_input(const Variable& var) {
1733     SkASSERT(var.storage() == VariableStorage::kGlobal);
1734     return var.modifierFlags() & ModifierFlag::kIn &&
1735            (var.layout().fBuiltin == -1 || is_compute_builtin(var)) &&
1736            var.type().typeKind() != Type::TypeKind::kTexture;
1737 }
1738 
1739 // true if the var is part of the Outputs struct
is_output(const Variable & var)1740 static bool is_output(const Variable& var) {
1741     SkASSERT(var.storage() == VariableStorage::kGlobal);
1742     // inout vars get written into the Inputs struct, so we exclude them from Outputs
1743     return  (var.modifierFlags() & ModifierFlag::kOut) &&
1744            !(var.modifierFlags() & ModifierFlag::kIn) &&
1745              var.layout().fBuiltin == -1 &&
1746              var.type().typeKind() != Type::TypeKind::kTexture;
1747 }
1748 
1749 // true if the var is part of the Uniforms struct
is_uniforms(const Variable & var)1750 static bool is_uniforms(const Variable& var) {
1751     SkASSERT(var.storage() == VariableStorage::kGlobal);
1752     return var.modifierFlags().isUniform() &&
1753            var.type().typeKind() != Type::TypeKind::kSampler;
1754 }
1755 
1756 // true if the var is part of the Threadgroups struct
is_threadgroup(const Variable & var)1757 static bool is_threadgroup(const Variable& var) {
1758     SkASSERT(var.storage() == VariableStorage::kGlobal);
1759     return var.modifierFlags().isWorkgroup();
1760 }
1761 
1762 // true if the var is part of the Globals struct
is_in_globals(const Variable & var)1763 static bool is_in_globals(const Variable& var) {
1764     SkASSERT(var.storage() == VariableStorage::kGlobal);
1765     return !var.modifierFlags().isConst();
1766 }
1767 
writeVariableReference(const VariableReference & ref)1768 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1769     switch (ref.variable()->layout().fBuiltin) {
1770         case SK_FRAGCOLOR_BUILTIN:
1771             this->write("_out.sk_FragColor");
1772             break;
1773         case SK_SAMPLEMASK_BUILTIN:
1774             this->write("_out.sk_SampleMask");
1775             break;
1776         case SK_SECONDARYFRAGCOLOR_BUILTIN:
1777             if (fCaps.fDualSourceBlendingSupport) {
1778                 this->write("_out.sk_SecondaryFragColor");
1779             } else {
1780                 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
1781             }
1782             break;
1783         case SK_FRAGCOORD_BUILTIN:
1784             this->writeFragCoord();
1785             break;
1786         case SK_SAMPLEMASKIN_BUILTIN:
1787             this->write("sk_SampleMaskIn");
1788             break;
1789         case SK_VERTEXID_BUILTIN:
1790             this->write("sk_VertexID");
1791             break;
1792         case SK_INSTANCEID_BUILTIN:
1793             this->write("sk_InstanceID");
1794             break;
1795         case SK_CLOCKWISE_BUILTIN:
1796             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1797             // clockwise to match Skia convention.
1798             if (!fRTFlipName.empty()) {
1799                 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1800             } else {
1801                 this->write("_frontFacing");
1802             }
1803             break;
1804         case SK_LASTFRAGCOLOR_BUILTIN:
1805             if (fCaps.fFBFetchColorName) {
1806                 this->write(fCaps.fFBFetchColorName);
1807             } else {
1808                 fContext.fErrors->error(ref.position(), "'sk_LastFragColor' not supported");
1809             }
1810             break;
1811         default:
1812             const Variable& var = *ref.variable();
1813             if (var.storage() == Variable::Storage::kGlobal) {
1814                 if (is_input(var)) {
1815                     this->write("_in.");
1816                 } else if (is_output(var)) {
1817                     this->write("_out.");
1818                 } else if (is_uniforms(var)) {
1819                     this->write("_uniforms.");
1820                 } else if (is_threadgroup(var)) {
1821                     this->write("_threadgroups.");
1822                 } else if (is_in_globals(var)) {
1823                     this->write("_globals.");
1824                 }
1825             }
1826             this->writeName(var.mangledName());
1827     }
1828 }
1829 
writeIndexInnerExpression(const Expression & expr)1830 void MetalCodeGenerator::writeIndexInnerExpression(const Expression& expr) {
1831     if (fIndexSubstitutionData) {
1832         // If this expression already exists in the index-substitution map, use the substitute.
1833         if (const std::string* existing = fIndexSubstitutionData->fMap.find(&expr)) {
1834             this->write(*existing);
1835             return;
1836         }
1837 
1838         // If this expression is non-trivial, we will need to create a scratch variable and store
1839         // its value there.
1840         if (fIndexSubstitutionData->fCreateSubstitutes && !Analysis::IsTrivialExpression(expr)) {
1841             // Create a substitute variable and emit it into the main stream.
1842             std::string scratchVar = this->getTempVariable(expr.type());
1843             this->write(scratchVar);
1844 
1845             // Initialize the substitute variable in the prefix-stream.
1846             AutoOutputStream outputToPrefixStream(this, &fIndexSubstitutionData->fPrefixStream);
1847             this->write(scratchVar);
1848             this->write(" = ");
1849             this->writeExpression(expr, Precedence::kAssignment);
1850             this->write(", ");
1851 
1852             // Remember the substitute variable in our map.
1853             fIndexSubstitutionData->fMap.set(&expr, std::move(scratchVar));
1854             return;
1855         }
1856     }
1857 
1858     // We don't require index-substitution; just emit the expression normally.
1859     this->writeExpression(expr, Precedence::kExpression);
1860 }
1861 
writeIndexExpression(const IndexExpression & expr)1862 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1863     // Metal does not seem to handle assignment into `vec.zyx[i]` properly--it compiles, but the
1864     // results are wrong. We rewrite the expression as `vec[uint3(2,1,0)[i]]` instead. (Filed with
1865     // Apple as FB12055941.)
1866     if (expr.base()->is<Swizzle>() && expr.base()->as<Swizzle>().components().size() > 1) {
1867         const Swizzle& swizzle = expr.base()->as<Swizzle>();
1868         this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1869         this->write("[uint" + std::to_string(swizzle.components().size()) + "(");
1870         auto separator = SkSL::String::Separator();
1871         for (int8_t component : swizzle.components()) {
1872             this->write(separator());
1873             this->write(std::to_string(component));
1874         }
1875         this->write(")[");
1876         this->writeIndexInnerExpression(*expr.index());
1877         this->write("]]");
1878     } else {
1879         this->writeExpression(*expr.base(), Precedence::kPostfix);
1880         this->write("[");
1881         this->writeIndexInnerExpression(*expr.index());
1882         this->write("]");
1883     }
1884 }
1885 
writeFieldAccess(const FieldAccess & f)1886 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1887     const Field* field = &f.base()->type().fields()[f.fieldIndex()];
1888     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1889         this->writeExpression(*f.base(), Precedence::kPostfix);
1890         this->write(".");
1891     }
1892     switch (field->fLayout.fBuiltin) {
1893         case SK_POSITION_BUILTIN:
1894             this->write("_out.sk_Position");
1895             break;
1896         case SK_POINTSIZE_BUILTIN:
1897             this->write("_out.sk_PointSize");
1898             break;
1899         default:
1900             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1901                 this->write("_globals.");
1902                 this->write(fInterfaceBlockNameMap[&f.base()->type()]);
1903                 this->write("->");
1904             }
1905             this->writeName(field->fName);
1906     }
1907 }
1908 
writeSwizzle(const Swizzle & swizzle)1909 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1910     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1911     this->write(".");
1912     this->write(Swizzle::MaskString(swizzle.components()));
1913 }
1914 
writeMatrixTimesEqualHelper(const Type & left,const Type & right,const Type & result)1915 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1916                                                      const Type& result) {
1917     SkASSERT(left.isMatrix());
1918     SkASSERT(right.isMatrix());
1919     SkASSERT(result.isMatrix());
1920 
1921     std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1922 
1923     if (!fHelpers.contains(key)) {
1924         fHelpers.add(key);
1925         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1926                                "    left = left * right;\n"
1927                                "    return left;\n"
1928                                "}\n",
1929                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1930                                this->typeName(right).c_str());
1931     }
1932 }
1933 
writeMatrixEqualityHelpers(const Type & left,const Type & right)1934 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1935     SkASSERT(left.isMatrix());
1936     SkASSERT(right.isMatrix());
1937     SkASSERT(left.rows() == right.rows());
1938     SkASSERT(left.columns() == right.columns());
1939 
1940     std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1941 
1942     if (!fHelpers.contains(key)) {
1943         fHelpers.add(key);
1944         fExtraFunctionPrototypes.printf(R"(
1945 thread bool operator==(const %s left, const %s right);
1946 thread bool operator!=(const %s left, const %s right);
1947 )",
1948                                         this->typeName(left).c_str(),
1949                                         this->typeName(right).c_str(),
1950                                         this->typeName(left).c_str(),
1951                                         this->typeName(right).c_str());
1952 
1953         fExtraFunctions.printf(
1954                 "thread bool operator==(const %s left, const %s right) {\n"
1955                 "    return ",
1956                 this->typeName(left).c_str(), this->typeName(right).c_str());
1957 
1958         const char* separator = "";
1959         for (int index=0; index<left.columns(); ++index) {
1960             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1961             separator = " &&\n           ";
1962         }
1963 
1964         fExtraFunctions.printf(
1965                 ";\n"
1966                 "}\n"
1967                 "thread bool operator!=(const %s left, const %s right) {\n"
1968                 "    return !(left == right);\n"
1969                 "}\n",
1970                 this->typeName(left).c_str(), this->typeName(right).c_str());
1971     }
1972 }
1973 
writeMatrixDivisionHelpers(const Type & type)1974 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1975     SkASSERT(type.isMatrix());
1976 
1977     std::string key = "Matrix / " + this->typeName(type);
1978 
1979     if (!fHelpers.contains(key)) {
1980         fHelpers.add(key);
1981         std::string typeName = this->typeName(type);
1982 
1983         fExtraFunctions.printf(
1984                 "thread %s operator/(const %s left, const %s right) {\n"
1985                 "    return %s(",
1986                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1987 
1988         const char* separator = "";
1989         for (int index=0; index<type.columns(); ++index) {
1990             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1991             separator = ", ";
1992         }
1993 
1994         fExtraFunctions.printf(");\n"
1995                                "}\n"
1996                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1997                                "    left = left / right;\n"
1998                                "    return left;\n"
1999                                "}\n",
2000                                typeName.c_str(), typeName.c_str(), typeName.c_str());
2001     }
2002 }
2003 
writeArrayEqualityHelpers(const Type & type)2004 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
2005     SkASSERT(type.isArray());
2006 
2007     // If the array's component type needs a helper as well, we need to emit that one first.
2008     this->writeEqualityHelpers(type.componentType(), type.componentType());
2009 
2010     std::string key = "ArrayEquality []";
2011     if (!fHelpers.contains(key)) {
2012         fHelpers.add(key);
2013         fExtraFunctionPrototypes.writeText(R"(
2014 template <typename T1, typename T2>
2015 bool operator==(const array_ref<T1> left, const array_ref<T2> right);
2016 template <typename T1, typename T2>
2017 bool operator!=(const array_ref<T1> left, const array_ref<T2> right);
2018 )");
2019         fExtraFunctions.writeText(R"(
2020 template <typename T1, typename T2>
2021 bool operator==(const array_ref<T1> left, const array_ref<T2> right) {
2022     if (left.size() != right.size()) {
2023         return false;
2024     }
2025     for (size_t index = 0; index < left.size(); ++index) {
2026         if (!all(left[index] == right[index])) {
2027             return false;
2028         }
2029     }
2030     return true;
2031 }
2032 
2033 template <typename T1, typename T2>
2034 bool operator!=(const array_ref<T1> left, const array_ref<T2> right) {
2035     return !(left == right);
2036 }
2037 )");
2038     }
2039 }
2040 
writeStructEqualityHelpers(const Type & type)2041 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
2042     SkASSERT(type.isStruct());
2043     std::string key = "StructEquality " + this->typeName(type);
2044 
2045     if (!fHelpers.contains(key)) {
2046         fHelpers.add(key);
2047         // If one of the struct's fields needs a helper as well, we need to emit that one first.
2048         for (const Field& field : type.fields()) {
2049             this->writeEqualityHelpers(*field.fType, *field.fType);
2050         }
2051 
2052         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
2053         // and GLSL but do not exist by default in Metal.
2054         fExtraFunctionPrototypes.printf(R"(
2055 thread bool operator==(thread const %s& left, thread const %s& right);
2056 thread bool operator!=(thread const %s& left, thread const %s& right);
2057 )",
2058                                         this->typeName(type).c_str(),
2059                                         this->typeName(type).c_str(),
2060                                         this->typeName(type).c_str(),
2061                                         this->typeName(type).c_str());
2062 
2063         fExtraFunctions.printf(
2064                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
2065                 "    return ",
2066                 this->typeName(type).c_str(),
2067                 this->typeName(type).c_str());
2068 
2069         const char* separator = "";
2070         for (const Field& field : type.fields()) {
2071             if (field.fType->isArray()) {
2072                 fExtraFunctions.printf(
2073                         "%s(make_array_ref(left.%.*s) == make_array_ref(right.%.*s))",
2074                         separator,
2075                         (int)field.fName.size(), field.fName.data(),
2076                         (int)field.fName.size(), field.fName.data());
2077             } else {
2078                 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
2079                                        separator,
2080                                        (int)field.fName.size(), field.fName.data(),
2081                                        (int)field.fName.size(), field.fName.data());
2082             }
2083             separator = " &&\n           ";
2084         }
2085         fExtraFunctions.printf(
2086                 ";\n"
2087                 "}\n"
2088                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
2089                 "    return !(left == right);\n"
2090                 "}\n",
2091                 this->typeName(type).c_str(),
2092                 this->typeName(type).c_str());
2093     }
2094 }
2095 
writeEqualityHelpers(const Type & leftType,const Type & rightType)2096 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
2097     if (leftType.isArray() && rightType.isArray()) {
2098         this->writeArrayEqualityHelpers(leftType);
2099         return;
2100     }
2101     if (leftType.isStruct() && rightType.isStruct()) {
2102         this->writeStructEqualityHelpers(leftType);
2103         return;
2104     }
2105     if (leftType.isMatrix() && rightType.isMatrix()) {
2106         this->writeMatrixEqualityHelpers(leftType, rightType);
2107         return;
2108     }
2109 }
2110 
splatMatrixOf1(const Type & type)2111 std::string MetalCodeGenerator::splatMatrixOf1(const Type& type) {
2112     std::string str = this->typeName(type) + '(';
2113 
2114     auto separator = SkSL::String::Separator();
2115     for (int index = type.slotCount(); index--;) {
2116         str += separator();
2117         str += "1.0";
2118     }
2119 
2120     return str + ')';
2121 }
2122 
writeNumberAsMatrix(const Expression & expr,const Type & matrixType)2123 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
2124     SkASSERT(expr.type().isNumber());
2125     SkASSERT(matrixType.isMatrix());
2126 
2127     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
2128     this->write("(");
2129     this->write(this->splatMatrixOf1(matrixType));
2130     this->write(" * ");
2131     this->writeExpression(expr, Precedence::kMultiplicative);
2132     this->write(")");
2133 }
2134 
writeBinaryExpressionElement(const Expression & expr,Operator op,const Expression & other,Precedence precedence)2135 void MetalCodeGenerator::writeBinaryExpressionElement(const Expression& expr,
2136                                                       Operator op,
2137                                                       const Expression& other,
2138                                                       Precedence precedence) {
2139     bool needMatrixSplatOnScalar = other.type().isMatrix() && expr.type().isNumber() &&
2140                                    op.isValidForMatrixOrVector() &&
2141                                    op.removeAssignment().kind() != Operator::Kind::STAR;
2142     if (needMatrixSplatOnScalar) {
2143         this->writeNumberAsMatrix(expr, other.type());
2144     } else if (op.isEquality() && expr.type().isArray()) {
2145         this->write("make_array_ref(");
2146         this->writeExpression(expr, precedence);
2147         this->write(")");
2148     } else {
2149         this->writeExpression(expr, precedence);
2150     }
2151 }
2152 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)2153 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
2154                                                Precedence parentPrecedence) {
2155     const Expression& left = *b.left();
2156     const Expression& right = *b.right();
2157     const Type& leftType = left.type();
2158     const Type& rightType = right.type();
2159     Operator op = b.getOperator();
2160     Precedence precedence = op.getBinaryPrecedence();
2161     bool needParens = precedence >= parentPrecedence;
2162     switch (op.kind()) {
2163         case Operator::Kind::EQEQ:
2164             this->writeEqualityHelpers(leftType, rightType);
2165             if (leftType.isVector()) {
2166                 this->write("all");
2167                 needParens = true;
2168             }
2169             break;
2170         case Operator::Kind::NEQ:
2171             this->writeEqualityHelpers(leftType, rightType);
2172             if (leftType.isVector()) {
2173                 this->write("any");
2174                 needParens = true;
2175             }
2176             break;
2177         default:
2178             break;
2179     }
2180     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
2181         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
2182     }
2183     if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
2184         ((leftType.isMatrix() && rightType.isMatrix()) ||
2185          (leftType.isScalar() && rightType.isMatrix()) ||
2186          (leftType.isMatrix() && rightType.isScalar()))) {
2187         this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
2188     }
2189 
2190     if (needParens) {
2191         this->write("(");
2192     }
2193 
2194     // Some expressions need to be rewritten from `lhs *= rhs` to `lhs = lhs * rhs`, e.g.:
2195     //     float4 x = float4(1);
2196     //     x.xy *= float2x2(...);
2197     // will report the error "non-const reference cannot bind to vector element."
2198     if (op.isCompoundAssignment() && left.kind() == Expression::Kind::kSwizzle) {
2199         // We need to do the rewrite. This could be dangerous if the lhs contains an index
2200         // expression with a side effect (such as `array[Func()]`), so we enable index-substitution
2201         // here for the LHS; any index-expression with side effects will be evaluated into a scratch
2202         // variable.
2203         this->writeWithIndexSubstitution([&] {
2204             this->writeExpression(left, precedence);
2205             this->write(" = ");
2206             this->writeExpression(left, Precedence::kAssignment);
2207             this->write(operator_name(op.removeAssignment()));
2208 
2209             // We never want to create index-expression substitutes on the RHS of the expression;
2210             // the RHS is only emitted one time.
2211             fIndexSubstitutionData->fCreateSubstitutes = false;
2212 
2213             this->writeBinaryExpressionElement(right, op, left,
2214                                                op.removeAssignment().getBinaryPrecedence());
2215         });
2216     } else {
2217         // We don't need any rewrite; emit the binary expression as-is.
2218         this->writeBinaryExpressionElement(left, op, right, precedence);
2219         this->write(operator_name(op));
2220         this->writeBinaryExpressionElement(right, op, left, precedence);
2221     }
2222 
2223     if (needParens) {
2224         this->write(")");
2225     }
2226 }
2227 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)2228 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
2229                                                Precedence parentPrecedence) {
2230     if (Precedence::kTernary >= parentPrecedence) {
2231         this->write("(");
2232     }
2233     this->writeExpression(*t.test(), Precedence::kTernary);
2234     this->write(" ? ");
2235     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
2236     this->write(" : ");
2237     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
2238     if (Precedence::kTernary >= parentPrecedence) {
2239         this->write(")");
2240     }
2241 }
2242 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)2243 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
2244                                                Precedence parentPrecedence) {
2245     const Operator op = p.getOperator();
2246     switch (op.kind()) {
2247         case Operator::Kind::PLUS:
2248             // According to the MSL specification, the arithmetic unary operators (+ and –) do not
2249             // act upon matrix-typed operands. We treat the unary "+" as a no-op for all operands.
2250             this->writeExpression(*p.operand(), Precedence::kPrefix);
2251             return;
2252 
2253         case Operator::Kind::MINUS:
2254             // Transform the unary `-` on a matrix type to a multiplication by -1.
2255             if (p.operand()->type().isMatrix()) {
2256                 this->write(p.type().componentType().highPrecision() ? "(-1.0 * "
2257                                                                      : "(-1.0h * ");
2258                 this->writeExpression(*p.operand(), Precedence::kMultiplicative);
2259                 this->write(")");
2260                 return;
2261             }
2262             break;
2263 
2264         case Operator::Kind::PLUSPLUS:
2265         case Operator::Kind::MINUSMINUS:
2266             if (p.operand()->type().isMatrix()) {
2267                 // Transform `++x` or `--x` on a matrix type to `mat += T(1.0, ...)` or
2268                 // `mat -= T(1.0, ...)`.
2269                 this->write("(");
2270                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2271                 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2272                 this->write(this->splatMatrixOf1(p.operand()->type()));
2273                 this->write(")");
2274                 return;
2275             }
2276             break;
2277 
2278         default:
2279             break;
2280     }
2281 
2282     if (Precedence::kPrefix >= parentPrecedence) {
2283         this->write("(");
2284     }
2285 
2286     this->write(op.tightOperatorName());
2287     this->writeExpression(*p.operand(), Precedence::kPrefix);
2288 
2289     if (Precedence::kPrefix >= parentPrecedence) {
2290         this->write(")");
2291     }
2292 }
2293 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)2294 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
2295                                                 Precedence parentPrecedence) {
2296     const Operator op = p.getOperator();
2297     switch (op.kind()) {
2298         case Operator::Kind::PLUSPLUS:
2299         case Operator::Kind::MINUSMINUS:
2300             if (p.operand()->type().isMatrix()) {
2301                 // We need to transform `x++` or `x--` into `+=` and `-=` on a matrix.
2302                 // Unfortunately, that requires making a temporary copy of the old value and
2303                 // emitting a sequence expression: `((temp = mat), (mat += T(1.0, ...)), temp)`.
2304                 std::string tempMatrix = this->getTempVariable(p.operand()->type());
2305                 this->write("((");
2306                 this->write(tempMatrix);
2307                 this->write(" = ");
2308                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2309                 this->write("), (");
2310                 this->writeExpression(*p.operand(), Precedence::kAssignment);
2311                 this->write(op.kind() == Operator::Kind::PLUSPLUS ? " += " : " -= ");
2312                 this->write(this->splatMatrixOf1(p.operand()->type()));
2313                 this->write("), ");
2314                 this->write(tempMatrix);
2315                 this->write(")");
2316                 return;
2317             }
2318             break;
2319 
2320         default:
2321             break;
2322     }
2323 
2324     if (Precedence::kPostfix >= parentPrecedence) {
2325         this->write("(");
2326     }
2327     this->writeExpression(*p.operand(), Precedence::kPostfix);
2328     this->write(op.tightOperatorName());
2329     if (Precedence::kPostfix >= parentPrecedence) {
2330         this->write(")");
2331     }
2332 }
2333 
writeLiteral(const Literal & l)2334 void MetalCodeGenerator::writeLiteral(const Literal& l) {
2335     const Type& type = l.type();
2336     if (type.isFloat()) {
2337         this->write(l.description(OperatorPrecedence::kExpression));
2338         if (!l.type().highPrecision()) {
2339             this->write("h");
2340         }
2341         return;
2342     }
2343     if (type.isInteger()) {
2344         if (type.matches(*fContext.fTypes.fUInt)) {
2345             this->write(std::to_string(l.intValue() & 0xffffffff));
2346             this->write("u");
2347         } else if (type.matches(*fContext.fTypes.fUShort)) {
2348             this->write(std::to_string(l.intValue() & 0xffff));
2349             this->write("u");
2350         } else {
2351             this->write(std::to_string(l.intValue()));
2352         }
2353         return;
2354     }
2355     SkASSERT(type.isBoolean());
2356     this->write(l.description(OperatorPrecedence::kExpression));
2357 }
2358 
writeFunctionRequirementArgs(const FunctionDeclaration & f,const char * & separator)2359 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
2360                                                       const char*& separator) {
2361     Requirements requirements = this->requirements(f);
2362     if (requirements & kInputs_Requirement) {
2363         this->write(separator);
2364         this->write("_in");
2365         separator = ", ";
2366     }
2367     if (requirements & kOutputs_Requirement) {
2368         this->write(separator);
2369         this->write("_out");
2370         separator = ", ";
2371     }
2372     if (requirements & kUniforms_Requirement) {
2373         this->write(separator);
2374         this->write("_uniforms");
2375         separator = ", ";
2376     }
2377     if (requirements & kGlobals_Requirement) {
2378         this->write(separator);
2379         this->write("_globals");
2380         separator = ", ";
2381     }
2382     if (requirements & kFragCoord_Requirement) {
2383         this->write(separator);
2384         this->write("_fragCoord");
2385         separator = ", ";
2386     }
2387     if (requirements & kSampleMaskIn_Requirement) {
2388         this->write(separator);
2389         this->write("sk_SampleMaskIn");
2390         separator = ", ";
2391     }
2392     if (requirements & kVertexID_Requirement) {
2393         this->write(separator);
2394         this->write("sk_VertexID");
2395         separator = ", ";
2396     }
2397     if (requirements & kInstanceID_Requirement) {
2398         this->write(separator);
2399         this->write("sk_InstanceID");
2400         separator = ", ";
2401     }
2402     if (requirements & kThreadgroups_Requirement) {
2403         this->write(separator);
2404         this->write("_threadgroups");
2405         separator = ", ";
2406     }
2407 }
2408 
writeFunctionRequirementParams(const FunctionDeclaration & f,const char * & separator)2409 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
2410                                                         const char*& separator) {
2411     Requirements requirements = this->requirements(f);
2412     if (requirements & kInputs_Requirement) {
2413         this->write(separator);
2414         this->write("Inputs _in");
2415         separator = ", ";
2416     }
2417     if (requirements & kOutputs_Requirement) {
2418         this->write(separator);
2419         this->write("thread Outputs& _out");
2420         separator = ", ";
2421     }
2422     if (requirements & kUniforms_Requirement) {
2423         this->write(separator);
2424         this->write("Uniforms _uniforms");
2425         separator = ", ";
2426     }
2427     if (requirements & kGlobals_Requirement) {
2428         this->write(separator);
2429         this->write("thread Globals& _globals");
2430         separator = ", ";
2431     }
2432     if (requirements & kFragCoord_Requirement) {
2433         this->write(separator);
2434         this->write("float4 _fragCoord");
2435         separator = ", ";
2436     }
2437     if (requirements & kSampleMaskIn_Requirement) {
2438         this->write(separator);
2439         this->write("uint sk_SampleMaskIn");
2440         separator = ", ";
2441     }
2442     if (requirements & kVertexID_Requirement) {
2443         this->write(separator);
2444         this->write("uint sk_VertexID");
2445         separator = ", ";
2446     }
2447     if (requirements & kInstanceID_Requirement) {
2448         this->write(separator);
2449         this->write("uint sk_InstanceID");
2450         separator = ", ";
2451     }
2452     if (requirements & kThreadgroups_Requirement) {
2453         this->write(separator);
2454         this->write("threadgroup Threadgroups& _threadgroups");
2455         separator = ", ";
2456     }
2457 }
2458 
getUniformBinding(const Layout & layout)2459 int MetalCodeGenerator::getUniformBinding(const Layout& layout) {
2460     return (layout.fBinding >= 0) ? layout.fBinding
2461                                   : fProgram.fConfig->fSettings.fDefaultUniformBinding;
2462 }
2463 
getUniformSet(const Layout & layout)2464 int MetalCodeGenerator::getUniformSet(const Layout& layout) {
2465     return (layout.fSet >= 0) ? layout.fSet
2466                               : fProgram.fConfig->fSettings.fDefaultUniformSet;
2467 }
2468 
writeFunctionDeclaration(const FunctionDeclaration & f)2469 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
2470     fRTFlipName = (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None)
2471                           ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
2472                           : "";
2473     const char* separator = "";
2474     if (f.isMain()) {
2475         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2476             this->write("fragment Outputs fragmentMain(");
2477         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2478             this->write("vertex Outputs vertexMain(");
2479         } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2480             this->write("kernel void computeMain(");
2481         } else {
2482             fContext.fErrors->error(Position(), "unsupported kind of program");
2483             return false;
2484         }
2485         if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2486             this->write("Inputs _in [[stage_in]]");
2487             separator = ", ";
2488         }
2489         if (-1 != fUniformBuffer) {
2490             this->write(separator);
2491             this->write("constant Uniforms& _uniforms [[buffer(" +
2492                         std::to_string(fUniformBuffer) + ")]]");
2493             separator = ", ";
2494         }
2495         for (const ProgramElement* e : fProgram.elements()) {
2496             if (e->is<GlobalVarDeclaration>()) {
2497                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2498                 const VarDeclaration& decl = decls.varDeclaration();
2499                 const Variable* var = decl.var();
2500                 const SkSL::Type::TypeKind varKind = var->type().typeKind();
2501 
2502                 if (varKind == Type::TypeKind::kSampler || varKind == Type::TypeKind::kTexture) {
2503                     if (var->type().dimensions() != SpvDim2D) {
2504                         // Not yet implemented--Skia currently only uses 2D textures.
2505                         fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
2506                         return false;
2507                     }
2508 
2509                     int binding = getUniformBinding(var->layout());
2510                     this->write(separator);
2511                     separator = ", ";
2512 
2513                     if (varKind == Type::TypeKind::kSampler) {
2514                         this->writeType(var->type().textureType());
2515                         this->write(" ");
2516                         this->writeName(var->mangledName());
2517                         this->write(kTextureSuffix);
2518                         this->write(" [[texture(");
2519                         this->write(std::to_string(binding));
2520                         this->write(")]], sampler ");
2521                         this->writeName(var->mangledName());
2522                         this->write(kSamplerSuffix);
2523                         this->write(" [[sampler(");
2524                         this->write(std::to_string(binding));
2525                         this->write(")]]");
2526                     } else {
2527                         SkASSERT(varKind == Type::TypeKind::kTexture);
2528                         this->writeType(var->type());
2529                         this->write(" ");
2530                         this->writeName(var->mangledName());
2531                         this->write(" [[texture(");
2532                         this->write(std::to_string(binding));
2533                         this->write(")]]");
2534                     }
2535                 } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2536                     std::string_view attr;
2537                     switch (var->layout().fBuiltin) {
2538                         case SK_NUMWORKGROUPS_BUILTIN:
2539                             attr = " [[threadgroups_per_grid]]";
2540                             break;
2541                         case SK_WORKGROUPID_BUILTIN:
2542                             attr = " [[threadgroup_position_in_grid]]";
2543                             break;
2544                         case SK_LOCALINVOCATIONID_BUILTIN:
2545                             attr = " [[thread_position_in_threadgroup]]";
2546                             break;
2547                         case SK_GLOBALINVOCATIONID_BUILTIN:
2548                             attr = " [[thread_position_in_grid]]";
2549                             break;
2550                         case SK_LOCALINVOCATIONINDEX_BUILTIN:
2551                             attr = " [[thread_index_in_threadgroup]]";
2552                             break;
2553                         default:
2554                             break;
2555                     }
2556                     if (!attr.empty()) {
2557                         this->write(separator);
2558                         this->writeType(var->type());
2559                         this->write(" ");
2560                         this->write(var->name());
2561                         this->write(attr);
2562                         separator = ", ";
2563                     }
2564                 }
2565             } else if (e->is<InterfaceBlock>()) {
2566                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
2567                 if (intf.typeName() == "sk_PerVertex") {
2568                     continue;
2569                 }
2570                 this->write(separator);
2571                 if (is_readonly(intf)) {
2572                     this->write("const ");
2573                 }
2574                 this->write(is_buffer(intf) ? "device " : "constant ");
2575                 this->writeType(intf.var()->type());
2576                 this->write("& " );
2577                 this->write(fInterfaceBlockNameMap[&intf.var()->type()]);
2578                 this->write(" [[buffer(");
2579                 this->write(std::to_string(this->getUniformBinding(intf.var()->layout())));
2580                 this->write(")]]");
2581                 separator = ", ";
2582             }
2583         }
2584         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2585             if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None &&
2586                 fInterfaceBlockNameMap.empty()) {
2587                 this->write(separator);
2588                 this->write("constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
2589                 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
2590                 separator = ", ";
2591             }
2592             this->write(separator);
2593             this->write("bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]");
2594             if (this->requirements(f) & kSampleMaskIn_Requirement) {
2595                 this->write(", uint sk_SampleMaskIn [[sample_mask]]");
2596             }
2597             if (fProgram.fInterface.fUseLastFragColor && fCaps.fFBFetchColorName) {
2598                 this->write(", half4 " + std::string(fCaps.fFBFetchColorName) +
2599                             " [[color(0)]]\n");
2600             }
2601             separator = ", ";
2602         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2603             this->write(separator);
2604             this->write("uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
2605             separator = ", ";
2606         }
2607     } else {
2608         this->writeType(f.returnType());
2609         this->write(" ");
2610         this->writeName(f.mangledName());
2611         this->write("(");
2612         this->writeFunctionRequirementParams(f, separator);
2613     }
2614     for (const Variable* param : f.parameters()) {
2615         // This is a workaround for our test files. They use the runtime effect signature, so main
2616         // takes a coords parameter. We detect these at IR generation time, and we omit them from
2617         // the declaration here, so the function is valid Metal. (Well, valid as long as the
2618         // coordinates aren't actually referenced.)
2619         if (f.isMain() && param == f.getMainCoordsParameter()) {
2620             continue;
2621         }
2622         this->write(separator);
2623         separator = ", ";
2624         this->writeModifiers(param->modifierFlags());
2625         this->writeType(param->type());
2626         if (pass_by_reference(param->type(), param->modifierFlags())) {
2627             this->write("&");
2628         }
2629         this->write(" ");
2630         this->writeName(param->mangledName());
2631     }
2632     this->write(")");
2633     return true;
2634 }
2635 
writeFunctionPrototype(const FunctionPrototype & f)2636 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
2637     this->writeFunctionDeclaration(f.declaration());
2638     this->writeLine(";");
2639 }
2640 
is_block_ending_with_return(const Statement * stmt)2641 static bool is_block_ending_with_return(const Statement* stmt) {
2642     // This function detects (potentially nested) blocks that end in a return statement.
2643     if (!stmt->is<Block>()) {
2644         return false;
2645     }
2646     const StatementArray& block = stmt->as<Block>().children();
2647     for (int index = block.size(); index--; ) {
2648         stmt = block[index].get();
2649         if (stmt->is<ReturnStatement>()) {
2650             return true;
2651         }
2652         if (stmt->is<Block>()) {
2653             return is_block_ending_with_return(stmt);
2654         }
2655         if (!stmt->is<Nop>()) {
2656             break;
2657         }
2658     }
2659     return false;
2660 }
2661 
writeComputeMainInputs()2662 void MetalCodeGenerator::writeComputeMainInputs() {
2663     // Compute shaders only have input variables (e.g. sk_GlobalInvocationID) and access program
2664     // inputs/outputs via the Globals and Uniforms structs. We collect the allowed "in" parameters
2665     // into an Input struct here, since the rest of the code expects the normal _in / _out pattern.
2666     this->write("Inputs _in = { ");
2667     const char* separator = "";
2668     for (const ProgramElement* e : fProgram.elements()) {
2669         if (e->is<GlobalVarDeclaration>()) {
2670             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2671             const Variable* var = decls.varDeclaration().var();
2672             if (is_input(*var)) {
2673                 this->write(separator);
2674                 separator = ", ";
2675                 this->writeName(var->mangledName());
2676             }
2677         }
2678     }
2679     this->writeLine(" };");
2680 }
2681 
writeFunction(const FunctionDefinition & f)2682 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
2683     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
2684 
2685     if (!this->writeFunctionDeclaration(f.declaration())) {
2686         return;
2687     }
2688 
2689     fCurrentFunction = &f.declaration();
2690     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
2691 
2692     this->writeLine(" {");
2693 
2694     if (f.declaration().isMain()) {
2695         fIndentation++;
2696         this->writeGlobalInit();
2697         if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
2698             this->writeThreadgroupInit();
2699             this->writeComputeMainInputs();
2700         }
2701         else {
2702             this->writeLine("Outputs _out;");
2703             this->writeLine("(void)_out;");
2704         }
2705         fIndentation--;
2706     }
2707 
2708     fFunctionHeader.clear();
2709     StringStream buffer;
2710     {
2711         AutoOutputStream outputToBuffer(this, &buffer);
2712         fIndentation++;
2713         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
2714             if (!stmt->isEmpty()) {
2715                 this->writeStatement(*stmt);
2716                 this->finishLine();
2717             }
2718         }
2719         if (f.declaration().isMain()) {
2720             // If the main function doesn't end with a return, we need to synthesize one here.
2721             if (!is_block_ending_with_return(f.body().get())) {
2722                 this->writeReturnStatementFromMain();
2723                 this->finishLine();
2724             }
2725         }
2726         fIndentation--;
2727         this->writeLine("}");
2728     }
2729     this->write(fFunctionHeader);
2730     this->write(buffer.str());
2731 }
2732 
writeModifiers(ModifierFlags flags)2733 void MetalCodeGenerator::writeModifiers(ModifierFlags flags) {
2734     if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
2735         (flags & (ModifierFlag::kIn | ModifierFlag::kOut))) {
2736         this->write("device ");
2737     } else if (flags & ModifierFlag::kOut) {
2738         this->write("thread ");
2739     }
2740     if (flags.isConst()) {
2741         this->write("const ");
2742     }
2743 }
2744 
writeInterfaceBlock(const InterfaceBlock & intf)2745 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2746     if (intf.typeName() == "sk_PerVertex") {
2747         return;
2748     }
2749     const Type* structType = &intf.var()->type().componentType();
2750     this->writeModifiers(intf.var()->modifierFlags());
2751     this->write("struct ");
2752     this->writeType(*structType);
2753     this->writeLine(" {");
2754     fIndentation++;
2755     this->writeFields(structType->fields(), structType->fPosition);
2756     if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
2757         this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2758     }
2759     fIndentation--;
2760     this->write("}");
2761     if (!intf.instanceName().empty()) {
2762         this->write(" ");
2763         this->write(intf.instanceName());
2764         if (intf.arraySize() > 0) {
2765             this->write("[");
2766             this->write(std::to_string(intf.arraySize()));
2767             this->write("]");
2768         }
2769         fInterfaceBlockNameMap.set(&intf.var()->type(), std::string(intf.instanceName()));
2770     } else {
2771         fInterfaceBlockNameMap.set(&intf.var()->type(),
2772                                    "_anonInterface" + std::to_string(fAnonInterfaceCount++));
2773     }
2774     this->writeLine(";");
2775 }
2776 
writeFields(SkSpan<const Field> fields,Position parentPos)2777 void MetalCodeGenerator::writeFields(SkSpan<const Field> fields, Position parentPos) {
2778     MemoryLayout memoryLayout(MemoryLayout::Standard::kMetal);
2779     int currentOffset = 0;
2780     for (const Field& field : fields) {
2781         int fieldOffset = field.fLayout.fOffset;
2782         const Type* fieldType = field.fType;
2783         if (!memoryLayout.isSupported(*fieldType)) {
2784             fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
2785                                                 "' is not permitted here");
2786             return;
2787         }
2788         if (fieldOffset != -1) {
2789             if (currentOffset > fieldOffset) {
2790                 fContext.fErrors->error(field.fPosition,
2791                                         "offset of field '" + std::string(field.fName) +
2792                                         "' must be at least " + std::to_string(currentOffset));
2793                 return;
2794             } else if (currentOffset < fieldOffset) {
2795                 this->write("char pad");
2796                 this->write(std::to_string(fPaddingCount++));
2797                 this->write("[");
2798                 this->write(std::to_string(fieldOffset - currentOffset));
2799                 this->writeLine("];");
2800                 currentOffset = fieldOffset;
2801             }
2802             int alignment = memoryLayout.alignment(*fieldType);
2803             if (fieldOffset % alignment) {
2804                 fContext.fErrors->error(field.fPosition,
2805                                         "offset of field '" + std::string(field.fName) +
2806                                         "' must be a multiple of " + std::to_string(alignment));
2807                 return;
2808             }
2809         }
2810         if (fieldType->isUnsizedArray()) {
2811             // An unsized array always appears as the last member of a storage block. We declare
2812             // it as a one-element array and allow dereferencing past the capacity.
2813             // TODO(armansito): This is because C++ does not support flexible array members like C99
2814             // does. This generally works but it can lead to UB as compilers are free to insert
2815             // padding past the first element of the array. An alternative approach is to declare
2816             // the struct without the unsized array member and replace variable references with a
2817             // buffer offset calculation based on sizeof().
2818             this->writeModifiers(field.fModifierFlags);
2819             this->writeType(fieldType->componentType());
2820             this->write(" ");
2821             this->writeName(field.fName);
2822             this->write("[1]");
2823         } else {
2824             size_t fieldSize = memoryLayout.size(*fieldType);
2825             if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2826                 fContext.fErrors->error(parentPos, "field offset overflow");
2827                 return;
2828             }
2829             currentOffset += fieldSize;
2830             this->writeModifiers(field.fModifierFlags);
2831             this->writeType(*fieldType);
2832             this->write(" ");
2833             this->writeName(field.fName);
2834         }
2835         this->writeLine(";");
2836     }
2837 }
2838 
writeVarInitializer(const Variable & var,const Expression & value)2839 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2840     this->writeExpression(value, Precedence::kExpression);
2841 }
2842 
writeName(std::string_view name)2843 void MetalCodeGenerator::writeName(std::string_view name) {
2844     if (fReservedWords.contains(name)) {
2845         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2846     }
2847     this->write(name);
2848 }
2849 
writeVarDeclaration(const VarDeclaration & varDecl)2850 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2851     this->writeModifiers(varDecl.var()->modifierFlags());
2852     this->writeType(varDecl.var()->type());
2853     this->write(" ");
2854     this->writeName(varDecl.var()->mangledName());
2855     if (varDecl.value()) {
2856         this->write(" = ");
2857         this->writeVarInitializer(*varDecl.var(), *varDecl.value());
2858     }
2859     this->write(";");
2860 }
2861 
writeStatement(const Statement & s)2862 void MetalCodeGenerator::writeStatement(const Statement& s) {
2863     switch (s.kind()) {
2864         case Statement::Kind::kBlock:
2865             this->writeBlock(s.as<Block>());
2866             break;
2867         case Statement::Kind::kExpression:
2868             this->writeExpressionStatement(s.as<ExpressionStatement>());
2869             break;
2870         case Statement::Kind::kReturn:
2871             this->writeReturnStatement(s.as<ReturnStatement>());
2872             break;
2873         case Statement::Kind::kVarDeclaration:
2874             this->writeVarDeclaration(s.as<VarDeclaration>());
2875             break;
2876         case Statement::Kind::kIf:
2877             this->writeIfStatement(s.as<IfStatement>());
2878             break;
2879         case Statement::Kind::kFor:
2880             this->writeForStatement(s.as<ForStatement>());
2881             break;
2882         case Statement::Kind::kDo:
2883             this->writeDoStatement(s.as<DoStatement>());
2884             break;
2885         case Statement::Kind::kSwitch:
2886             this->writeSwitchStatement(s.as<SwitchStatement>());
2887             break;
2888         case Statement::Kind::kBreak:
2889             this->write("break;");
2890             break;
2891         case Statement::Kind::kContinue:
2892             this->write("continue;");
2893             break;
2894         case Statement::Kind::kDiscard:
2895             this->write("discard_fragment();");
2896             break;
2897         case Statement::Kind::kNop:
2898             this->write(";");
2899             break;
2900         default:
2901             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2902             break;
2903     }
2904 }
2905 
writeBlock(const Block & b)2906 void MetalCodeGenerator::writeBlock(const Block& b) {
2907     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2908     // something here to make the code valid).
2909     bool isScope = b.isScope() || b.isEmpty();
2910     if (isScope) {
2911         this->writeLine("{");
2912         fIndentation++;
2913     }
2914     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2915         if (!stmt->isEmpty()) {
2916             this->writeStatement(*stmt);
2917             this->finishLine();
2918         }
2919     }
2920     if (isScope) {
2921         fIndentation--;
2922         this->write("}");
2923     }
2924 }
2925 
writeIfStatement(const IfStatement & stmt)2926 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2927     this->write("if (");
2928     this->writeExpression(*stmt.test(), Precedence::kExpression);
2929     this->write(") ");
2930     this->writeStatement(*stmt.ifTrue());
2931     if (stmt.ifFalse()) {
2932         this->write(" else ");
2933         this->writeStatement(*stmt.ifFalse());
2934     }
2935 }
2936 
writeForStatement(const ForStatement & f)2937 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2938     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2939     if (!f.initializer() && f.test() && !f.next()) {
2940         this->write("while (");
2941         this->writeExpression(*f.test(), Precedence::kExpression);
2942         this->write(") ");
2943         this->writeStatement(*f.statement());
2944         return;
2945     }
2946 
2947     this->write("for (");
2948     if (f.initializer() && !f.initializer()->isEmpty()) {
2949         this->writeStatement(*f.initializer());
2950     } else {
2951         this->write("; ");
2952     }
2953     if (f.test()) {
2954         this->writeExpression(*f.test(), Precedence::kExpression);
2955     }
2956     this->write("; ");
2957     if (f.next()) {
2958         this->writeExpression(*f.next(), Precedence::kExpression);
2959     }
2960     this->write(") ");
2961     this->writeStatement(*f.statement());
2962 }
2963 
writeDoStatement(const DoStatement & d)2964 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2965     this->write("do ");
2966     this->writeStatement(*d.statement());
2967     this->write(" while (");
2968     this->writeExpression(*d.test(), Precedence::kExpression);
2969     this->write(");");
2970 }
2971 
writeExpressionStatement(const ExpressionStatement & s)2972 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2973     if (fProgram.fConfig->fSettings.fOptimize && !Analysis::HasSideEffects(*s.expression())) {
2974         // Don't emit dead expressions.
2975         return;
2976     }
2977     this->writeExpression(*s.expression(), Precedence::kStatement);
2978     this->write(";");
2979 }
2980 
writeSwitchStatement(const SwitchStatement & s)2981 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2982     this->write("switch (");
2983     this->writeExpression(*s.value(), Precedence::kExpression);
2984     this->writeLine(") {");
2985     fIndentation++;
2986     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2987         const SwitchCase& c = stmt->as<SwitchCase>();
2988         if (c.isDefault()) {
2989             this->writeLine("default:");
2990         } else {
2991             this->write("case ");
2992             this->write(std::to_string(c.value()));
2993             this->writeLine(":");
2994         }
2995         if (!c.statement()->isEmpty()) {
2996             fIndentation++;
2997             this->writeStatement(*c.statement());
2998             this->finishLine();
2999             fIndentation--;
3000         }
3001     }
3002     fIndentation--;
3003     this->write("}");
3004 }
3005 
writeReturnStatementFromMain()3006 void MetalCodeGenerator::writeReturnStatementFromMain() {
3007     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
3008     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
3009         ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3010         this->write("return _out;");
3011     } else if (ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3012         this->write("return;");
3013     } else {
3014         SkDEBUGFAIL("unsupported kind of program");
3015     }
3016 }
3017 
writeReturnStatement(const ReturnStatement & r)3018 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
3019     if (fCurrentFunction && fCurrentFunction->isMain()) {
3020         if (r.expression()) {
3021             if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
3022                 this->write("_out.sk_FragColor = ");
3023                 this->writeExpression(*r.expression(), Precedence::kExpression);
3024                 this->writeLine(";");
3025             } else {
3026                 fContext.fErrors->error(r.fPosition,
3027                         "Metal does not support returning '" +
3028                         r.expression()->type().description() + "' from main()");
3029             }
3030         }
3031         this->writeReturnStatementFromMain();
3032         return;
3033     }
3034 
3035     this->write("return");
3036     if (r.expression()) {
3037         this->write(" ");
3038         this->writeExpression(*r.expression(), Precedence::kExpression);
3039     }
3040     this->write(";");
3041 }
3042 
writeHeader()3043 void MetalCodeGenerator::writeHeader() {
3044     this->writeLine("#include <metal_stdlib>");
3045     this->writeLine("#include <simd/simd.h>");
3046     this->writeLine("#ifdef __clang__");
3047     this->writeLine("#pragma clang diagnostic ignored \"-Wall\"");
3048     this->writeLine("#endif");
3049     this->writeLine("using namespace metal;");
3050 }
3051 
writeSampler2DPolyfill()3052 void MetalCodeGenerator::writeSampler2DPolyfill() {
3053     class : public GlobalStructVisitor {
3054     public:
3055         void visitSampler(const Type&, std::string_view) override {
3056             if (fWrotePolyfill) {
3057                 return;
3058             }
3059             fWrotePolyfill = true;
3060 
3061             std::string polyfill = SkSL::String::printf(R"(
3062 struct sampler2D {
3063     texture2d<half> tex;
3064     sampler smp;
3065 };
3066 half4 sample(sampler2D i, float2 p, float b=%g) { return i.tex.sample(i.smp, p, bias(b)); }
3067 half4 sample(sampler2D i, float3 p, float b=%g) { return i.tex.sample(i.smp, p.xy / p.z, bias(b)); }
3068 half4 sampleLod(sampler2D i, float2 p, float lod) { return i.tex.sample(i.smp, p, level(lod)); }
3069 half4 sampleLod(sampler2D i, float3 p, float lod) {
3070     return i.tex.sample(i.smp, p.xy / p.z, level(lod));
3071 }
3072 half4 sampleGrad(sampler2D i, float2 p, float2 dPdx, float2 dPdy) {
3073     return i.tex.sample(i.smp, p, gradient2d(dPdx, dPdy));
3074 }
3075 
3076 )",
3077                                                         fTextureBias,
3078                                                         fTextureBias);
3079             fCodeGen->write(polyfill.c_str());
3080         }
3081 
3082         MetalCodeGenerator* fCodeGen = nullptr;
3083         float fTextureBias = 0.0f;
3084         bool fWrotePolyfill = false;
3085     } visitor;
3086 
3087     visitor.fCodeGen = this;
3088     visitor.fTextureBias = fProgram.fConfig->fSettings.fSharpenTextures ? kSharpenTexturesBias
3089                                                                         : 0.0f;
3090     this->visitGlobalStruct(&visitor);
3091 }
3092 
writeUniformStruct()3093 void MetalCodeGenerator::writeUniformStruct() {
3094     for (const ProgramElement* e : fProgram.elements()) {
3095         if (e->is<GlobalVarDeclaration>()) {
3096             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3097             const Variable& var = *decls.varDeclaration().var();
3098             if (var.modifierFlags().isUniform()) {
3099                 SkASSERT(var.type().typeKind() != Type::TypeKind::kSampler &&
3100                          var.type().typeKind() != Type::TypeKind::kTexture);
3101                 int uniformSet = this->getUniformSet(var.layout());
3102                 // Make sure that the program's uniform-set value is consistent throughout.
3103                 if (-1 == fUniformBuffer) {
3104                     this->write("struct Uniforms {\n");
3105                     fUniformBuffer = uniformSet;
3106                 } else if (uniformSet != fUniformBuffer) {
3107                     fContext.fErrors->error(decls.fPosition,
3108                             "Metal backend requires all uniforms to have the same "
3109                             "'layout(set=...)'");
3110                 }
3111                 this->write("    ");
3112                 this->writeType(var.type());
3113                 this->write(" ");
3114                 this->writeName(var.mangledName());
3115                 this->write(";\n");
3116             }
3117         }
3118     }
3119     if (-1 != fUniformBuffer) {
3120         this->write("};\n");
3121     }
3122 }
3123 
writeInterpolatedAttributes(const Variable & var)3124 void MetalCodeGenerator::writeInterpolatedAttributes(const Variable& var) {
3125     SkASSERT((is_output(var) && ProgramConfig::IsVertex(fProgram.fConfig->fKind)) ||
3126              (is_input(var) && ProgramConfig::IsFragment(fProgram.fConfig->fKind)));
3127     // Always include the location
3128     this->write(" [[user(locn");
3129     this->write(std::to_string(var.layout().fLocation));
3130     this->write(")");
3131 
3132     if (var.modifierFlags().isFlat()) {
3133         this->write(" flat");
3134     } else if (var.modifierFlags().isNoPerspective()) {
3135         this->write(" center_no_perspective");
3136     } // else default behavior is center_perspective
3137 
3138     this->write("]]");
3139 }
3140 
writeInputStruct()3141 void MetalCodeGenerator::writeInputStruct() {
3142     this->write("struct Inputs {\n");
3143     for (const ProgramElement* e : fProgram.elements()) {
3144         if (e->is<GlobalVarDeclaration>()) {
3145             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3146             const Variable& var = *decls.varDeclaration().var();
3147             if (is_input(var)) {
3148                 this->write("    ");
3149                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3150                     needs_address_space(var.type(), var.modifierFlags())) {
3151                     // TODO: address space support
3152                     this->write("device ");
3153                 }
3154                 this->writeType(var.type());
3155                 if (pass_by_reference(var.type(), var.modifierFlags())) {
3156                     this->write("&");
3157                 }
3158                 this->write(" ");
3159                 this->writeName(var.mangledName());
3160                 if (-1 != var.layout().fLocation) {
3161                     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3162                         this->write("  [[attribute(" + std::to_string(var.layout().fLocation) +
3163                                     ")]]");
3164                     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3165                         // Write attributes for the fragment input that are consistent with
3166                         // what's annotated on the vertex output.
3167                         this->writeInterpolatedAttributes(var);
3168                     }
3169                 }
3170                 this->write(";\n");
3171             }
3172         }
3173     }
3174     this->write("};\n");
3175 }
3176 
writeOutputStruct()3177 void MetalCodeGenerator::writeOutputStruct() {
3178     this->write("struct Outputs {\n");
3179     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3180         this->write("    float4 sk_Position [[position]];\n");
3181     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3182         this->write("    half4 sk_FragColor [[color(0)]];\n");
3183         if (fProgram.fInterface.fOutputSecondaryColor) {
3184             this->write("    half4 sk_SecondaryFragColor [[color(0), index(1)]];\n");
3185         }
3186     }
3187     for (const ProgramElement* e : fProgram.elements()) {
3188         if (e->is<GlobalVarDeclaration>()) {
3189             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
3190             const Variable& var = *decls.varDeclaration().var();
3191             if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3192                 this->write("    uint sk_SampleMask [[sample_mask]];\n");
3193                 continue;
3194             }
3195             if (is_output(var)) {
3196                 this->write("    ");
3197                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3198                     needs_address_space(var.type(), var.modifierFlags())) {
3199                     // TODO: address space support
3200                     this->write("device ");
3201                 }
3202                 this->writeType(var.type());
3203                 if (ProgramConfig::IsCompute(fProgram.fConfig->fKind) &&
3204                     pass_by_reference(var.type(), var.modifierFlags())) {
3205                     this->write("&");
3206                 }
3207                 this->write(" ");
3208                 this->writeName(var.mangledName());
3209 
3210                 int location = var.layout().fLocation;
3211                 if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind) && location < 0 &&
3212                         var.type().typeKind() != Type::TypeKind::kTexture) {
3213                     fContext.fErrors->error(var.fPosition,
3214                                             "Metal out variables must have 'layout(location=...)'");
3215                 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3216                     // Write attributes for the vertex output that are consistent with what's
3217                     // annotated on the fragment input.
3218                     this->writeInterpolatedAttributes(var);
3219                 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
3220                     this->write(" [[color(" + std::to_string(location) + ")");
3221                     int colorIndex = var.layout().fIndex;
3222                     if (colorIndex) {
3223                         this->write(", index(" + std::to_string(colorIndex) + ")");
3224                     }
3225                     this->write("]]");
3226                 }
3227                 this->write(";\n");
3228             }
3229         }
3230     }
3231     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
3232         this->write("    float sk_PointSize [[point_size]];\n");
3233     }
3234     this->write("};\n");
3235 }
3236 
writeInterfaceBlocks()3237 void MetalCodeGenerator::writeInterfaceBlocks() {
3238     bool wroteInterfaceBlock = false;
3239     for (const ProgramElement* e : fProgram.elements()) {
3240         if (e->is<InterfaceBlock>()) {
3241             this->writeInterfaceBlock(e->as<InterfaceBlock>());
3242             wroteInterfaceBlock = true;
3243         }
3244     }
3245     if (!wroteInterfaceBlock &&
3246         fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None) {
3247         this->writeLine("struct sksl_synthetic_uniforms {");
3248         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
3249         this->writeLine("};");
3250     }
3251 }
3252 
writeStructDefinitions()3253 void MetalCodeGenerator::writeStructDefinitions() {
3254     for (const ProgramElement* e : fProgram.elements()) {
3255         if (e->is<StructDefinition>()) {
3256             this->writeStructDefinition(e->as<StructDefinition>());
3257         }
3258     }
3259 }
3260 
writeConstantVariables()3261 void MetalCodeGenerator::writeConstantVariables() {
3262     class : public GlobalStructVisitor {
3263     public:
3264         void visitConstantVariable(const VarDeclaration& decl) override {
3265             fCodeGen->write("constant ");
3266             fCodeGen->writeVarDeclaration(decl);
3267             fCodeGen->finishLine();
3268         }
3269 
3270         MetalCodeGenerator* fCodeGen = nullptr;
3271     } visitor;
3272 
3273     visitor.fCodeGen = this;
3274     this->visitGlobalStruct(&visitor);
3275 }
3276 
visitGlobalStruct(GlobalStructVisitor * visitor)3277 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
3278     for (const ProgramElement* element : fProgram.elements()) {
3279         if (element->is<InterfaceBlock>()) {
3280             const auto* ib = &element->as<InterfaceBlock>();
3281             if (ib->typeName() != "sk_PerVertex") {
3282                 visitor->visitInterfaceBlock(*ib, fInterfaceBlockNameMap[&ib->var()->type()]);
3283             }
3284             continue;
3285         }
3286         if (!element->is<GlobalVarDeclaration>()) {
3287             continue;
3288         }
3289         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3290         const VarDeclaration& decl = global.varDeclaration();
3291         const Variable& var = *decl.var();
3292         if (decl.baseType().typeKind() == Type::TypeKind::kSampler) {
3293             visitor->visitSampler(var.type(), var.mangledName());
3294             continue;
3295         }
3296         if (decl.baseType().typeKind() == Type::TypeKind::kTexture) {
3297             visitor->visitTexture(var.type(), var.mangledName());
3298             continue;
3299         }
3300         if (!(var.modifierFlags() & ~ModifierFlag::kConst) && var.layout().fBuiltin == -1) {
3301             if (is_in_globals(var)) {
3302                 // Visit a regular global variable.
3303                 visitor->visitNonconstantVariable(var, decl.value().get());
3304             } else {
3305                 // Visit a constant-expression variable.
3306                 SkASSERT(var.modifierFlags().isConst());
3307                 visitor->visitConstantVariable(decl);
3308             }
3309         }
3310     }
3311 }
3312 
writeGlobalStruct()3313 void MetalCodeGenerator::writeGlobalStruct() {
3314     class : public GlobalStructVisitor {
3315     public:
3316         void visitInterfaceBlock(const InterfaceBlock& block,
3317                                  std::string_view blockName) override {
3318             this->addElement();
3319             fCodeGen->write("    ");
3320             if (is_readonly(block)) {
3321                 fCodeGen->write("const ");
3322             }
3323             fCodeGen->write(is_buffer(block) ? "device " : "constant ");
3324             fCodeGen->write(block.typeName());
3325             fCodeGen->write("* ");
3326             fCodeGen->writeName(blockName);
3327             fCodeGen->write(";\n");
3328         }
3329         void visitTexture(const Type& type, std::string_view name) override {
3330             this->addElement();
3331             fCodeGen->write("    ");
3332             fCodeGen->writeType(type);
3333             fCodeGen->write(" ");
3334             fCodeGen->writeName(name);
3335             fCodeGen->write(";\n");
3336         }
3337         void visitSampler(const Type&, std::string_view name) override {
3338             this->addElement();
3339             fCodeGen->write("    sampler2D ");
3340             fCodeGen->writeName(name);
3341             fCodeGen->write(";\n");
3342         }
3343         void visitConstantVariable(const VarDeclaration& decl) override {
3344             // Constants aren't added to the global struct.
3345         }
3346         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3347             this->addElement();
3348             fCodeGen->write("    ");
3349             fCodeGen->writeModifiers(var.modifierFlags());
3350             fCodeGen->writeType(var.type());
3351             fCodeGen->write(" ");
3352             fCodeGen->writeName(var.mangledName());
3353             fCodeGen->write(";\n");
3354         }
3355         void addElement() {
3356             if (fFirst) {
3357                 fCodeGen->write("struct Globals {\n");
3358                 fFirst = false;
3359             }
3360         }
3361         void finish() {
3362             if (!fFirst) {
3363                 fCodeGen->writeLine("};");
3364                 fFirst = true;
3365             }
3366         }
3367 
3368         MetalCodeGenerator* fCodeGen = nullptr;
3369         bool fFirst = true;
3370     } visitor;
3371 
3372     visitor.fCodeGen = this;
3373     this->visitGlobalStruct(&visitor);
3374     visitor.finish();
3375 }
3376 
writeGlobalInit()3377 void MetalCodeGenerator::writeGlobalInit() {
3378     class : public GlobalStructVisitor {
3379     public:
3380         void visitInterfaceBlock(const InterfaceBlock& blockType,
3381                                  std::string_view blockName) override {
3382             this->addElement();
3383             fCodeGen->write("&");
3384             fCodeGen->writeName(blockName);
3385         }
3386         void visitTexture(const Type&, std::string_view name) override {
3387             this->addElement();
3388             fCodeGen->writeName(name);
3389         }
3390         void visitSampler(const Type&, std::string_view name) override {
3391             this->addElement();
3392             fCodeGen->write("{");
3393             fCodeGen->writeName(name);
3394             fCodeGen->write(kTextureSuffix);
3395             fCodeGen->write(", ");
3396             fCodeGen->writeName(name);
3397             fCodeGen->write(kSamplerSuffix);
3398             fCodeGen->write("}");
3399         }
3400         void visitConstantVariable(const VarDeclaration& decl) override {
3401             // Constant-expression variables aren't put in the global struct.
3402         }
3403         void visitNonconstantVariable(const Variable& var, const Expression* value) override {
3404             this->addElement();
3405             if (value) {
3406                 fCodeGen->writeVarInitializer(var, *value);
3407             } else {
3408                 fCodeGen->write("{}");
3409             }
3410         }
3411         void addElement() {
3412             if (fFirst) {
3413                 fCodeGen->write("Globals _globals{");
3414                 fFirst = false;
3415             } else {
3416                 fCodeGen->write(", ");
3417             }
3418         }
3419         void finish() {
3420             if (!fFirst) {
3421                 fCodeGen->writeLine("};");
3422                 fCodeGen->writeLine("(void)_globals;");
3423             }
3424         }
3425         MetalCodeGenerator* fCodeGen = nullptr;
3426         bool fFirst = true;
3427     } visitor;
3428 
3429     visitor.fCodeGen = this;
3430     this->visitGlobalStruct(&visitor);
3431     visitor.finish();
3432 }
3433 
visitThreadgroupStruct(ThreadgroupStructVisitor * visitor)3434 void MetalCodeGenerator::visitThreadgroupStruct(ThreadgroupStructVisitor* visitor) {
3435     for (const ProgramElement* element : fProgram.elements()) {
3436         if (!element->is<GlobalVarDeclaration>()) {
3437             continue;
3438         }
3439         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
3440         const VarDeclaration& decl = global.varDeclaration();
3441         const Variable& var = *decl.var();
3442         if (var.modifierFlags().isWorkgroup()) {
3443             SkASSERT(!decl.value());
3444             SkASSERT(!var.modifierFlags().isConst());
3445             visitor->visitNonconstantVariable(var);
3446         }
3447     }
3448 }
3449 
writeThreadgroupStruct()3450 void MetalCodeGenerator::writeThreadgroupStruct() {
3451     class : public ThreadgroupStructVisitor {
3452     public:
3453         void visitNonconstantVariable(const Variable& var) override {
3454             this->addElement();
3455             fCodeGen->write("    ");
3456             fCodeGen->writeModifiers(var.modifierFlags());
3457             fCodeGen->writeType(var.type());
3458             fCodeGen->write(" ");
3459             fCodeGen->writeName(var.mangledName());
3460             fCodeGen->write(";\n");
3461         }
3462         void addElement() {
3463             if (fFirst) {
3464                 fCodeGen->write("struct Threadgroups {\n");
3465                 fFirst = false;
3466             }
3467         }
3468         void finish() {
3469             if (!fFirst) {
3470                 fCodeGen->writeLine("};");
3471                 fFirst = true;
3472             }
3473         }
3474 
3475         MetalCodeGenerator* fCodeGen = nullptr;
3476         bool fFirst = true;
3477     } visitor;
3478 
3479     visitor.fCodeGen = this;
3480     this->visitThreadgroupStruct(&visitor);
3481     visitor.finish();
3482 }
3483 
writeThreadgroupInit()3484 void MetalCodeGenerator::writeThreadgroupInit() {
3485     class : public ThreadgroupStructVisitor {
3486     public:
3487         void visitNonconstantVariable(const Variable& var) override {
3488             this->addElement();
3489             fCodeGen->write("{}");
3490         }
3491         void addElement() {
3492             if (fFirst) {
3493                 fCodeGen->write("threadgroup Threadgroups _threadgroups{");
3494                 fFirst = false;
3495             } else {
3496                 fCodeGen->write(", ");
3497             }
3498         }
3499         void finish() {
3500             if (!fFirst) {
3501                 fCodeGen->writeLine("};");
3502                 fCodeGen->writeLine("(void)_threadgroups;");
3503             }
3504         }
3505         MetalCodeGenerator* fCodeGen = nullptr;
3506         bool fFirst = true;
3507     } visitor;
3508 
3509     visitor.fCodeGen = this;
3510     this->visitThreadgroupStruct(&visitor);
3511     visitor.finish();
3512 }
3513 
writeProgramElement(const ProgramElement & e)3514 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
3515     switch (e.kind()) {
3516         case ProgramElement::Kind::kExtension:
3517             break;
3518         case ProgramElement::Kind::kGlobalVar:
3519             break;
3520         case ProgramElement::Kind::kInterfaceBlock:
3521             // Handled in writeInterfaceBlocks; do nothing.
3522             break;
3523         case ProgramElement::Kind::kStructDefinition:
3524             // Handled in writeStructDefinitions; do nothing.
3525             break;
3526         case ProgramElement::Kind::kFunction:
3527             this->writeFunction(e.as<FunctionDefinition>());
3528             break;
3529         case ProgramElement::Kind::kFunctionPrototype:
3530             this->writeFunctionPrototype(e.as<FunctionPrototype>());
3531             break;
3532         case ProgramElement::Kind::kModifiers:
3533             // Not necessary in Metal; do nothing.
3534             break;
3535         default:
3536             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
3537             break;
3538     }
3539 }
3540 
requirements(const Statement * s)3541 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
3542     class RequirementsVisitor : public ProgramVisitor {
3543     public:
3544         using ProgramVisitor::visitStatement;
3545 
3546         bool visitExpression(const Expression& e) override {
3547             switch (e.kind()) {
3548                 case Expression::Kind::kFunctionCall: {
3549                     const FunctionCall& f = e.as<FunctionCall>();
3550                     fRequirements |= fCodeGen->requirements(f.function());
3551                     break;
3552                 }
3553                 case Expression::Kind::kFieldAccess: {
3554                     const FieldAccess& f = e.as<FieldAccess>();
3555                     if (f.ownerKind() == FieldAccess::OwnerKind::kAnonymousInterfaceBlock) {
3556                         fRequirements |= kGlobals_Requirement;
3557                         return false;  // don't recurse into the base variable
3558                     }
3559                     break;
3560                 }
3561                 case Expression::Kind::kVariableReference: {
3562                     const Variable& var = *e.as<VariableReference>().variable();
3563 
3564                     if (var.layout().fBuiltin == SK_FRAGCOORD_BUILTIN) {
3565                         fRequirements |= kGlobals_Requirement | kFragCoord_Requirement;
3566                     } else if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN) {
3567                         fRequirements |= kSampleMaskIn_Requirement;
3568                     } else if (var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3569                         fRequirements |= kOutputs_Requirement;
3570                     } else if (var.layout().fBuiltin == SK_VERTEXID_BUILTIN) {
3571                         fRequirements |= kVertexID_Requirement;
3572                     } else if (var.layout().fBuiltin == SK_INSTANCEID_BUILTIN) {
3573                         fRequirements |= kInstanceID_Requirement;
3574                     } else if (var.storage() == Variable::Storage::kGlobal) {
3575                         if (is_input(var)) {
3576                             fRequirements |= kInputs_Requirement;
3577                         } else if (is_output(var)) {
3578                             fRequirements |= kOutputs_Requirement;
3579                         } else if (is_uniforms(var)) {
3580                             fRequirements |= kUniforms_Requirement;
3581                         } else if (is_threadgroup(var)) {
3582                             fRequirements |= kThreadgroups_Requirement;
3583                         } else if (is_in_globals(var)) {
3584                             fRequirements |= kGlobals_Requirement;
3585                         }
3586                     }
3587                     break;
3588                 }
3589                 default:
3590                     break;
3591             }
3592             return ProgramVisitor::visitExpression(e);
3593         }
3594 
3595         MetalCodeGenerator* fCodeGen;
3596         Requirements fRequirements = kNo_Requirements;
3597     };
3598 
3599     RequirementsVisitor visitor;
3600     if (s) {
3601         visitor.fCodeGen = this;
3602         visitor.visitStatement(*s);
3603     }
3604     return visitor.fRequirements;
3605 }
3606 
requirements(const FunctionDeclaration & f)3607 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
3608     Requirements* found = fRequirements.find(&f);
3609     if (!found) {
3610         fRequirements.set(&f, kNo_Requirements);
3611         for (const ProgramElement* e : fProgram.elements()) {
3612             if (e->is<FunctionDefinition>()) {
3613                 const FunctionDefinition& def = e->as<FunctionDefinition>();
3614                 if (&def.declaration() == &f) {
3615                     Requirements reqs = this->requirements(def.body().get());
3616                     fRequirements.set(&f, reqs);
3617                     return reqs;
3618                 }
3619             }
3620         }
3621 
3622         // We never found a definition for this declared function, but it's legal to prototype a
3623         // function without ever giving a definition, as long as you don't call it.
3624         return kNo_Requirements;
3625     }
3626     return *found;
3627 }
3628 
generateCode()3629 bool MetalCodeGenerator::generateCode() {
3630     StringStream header;
3631     {
3632         AutoOutputStream outputToHeader(this, &header, &fIndentation);
3633         this->writeHeader();
3634         this->writeConstantVariables();
3635         this->writeSampler2DPolyfill();
3636         this->writeStructDefinitions();
3637         this->writeUniformStruct();
3638         this->writeInputStruct();
3639         if (!ProgramConfig::IsCompute(fProgram.fConfig->fKind)) {
3640             this->writeOutputStruct();
3641         }
3642         this->writeInterfaceBlocks();
3643         this->writeGlobalStruct();
3644         this->writeThreadgroupStruct();
3645 
3646         // Emit prototypes for every built-in function; these aren't always added in perfect order.
3647         for (const ProgramElement* e : fProgram.fSharedElements) {
3648             if (e->is<FunctionDefinition>()) {
3649                 this->writeFunctionDeclaration(e->as<FunctionDefinition>().declaration());
3650                 this->writeLine(";");
3651             }
3652         }
3653     }
3654     StringStream body;
3655     {
3656         AutoOutputStream outputToBody(this, &body, &fIndentation);
3657 
3658         for (const ProgramElement* e : fProgram.elements()) {
3659             this->writeProgramElement(*e);
3660         }
3661     }
3662     write_stringstream(header, *fOut);
3663     write_stringstream(fExtraFunctionPrototypes, *fOut);
3664     write_stringstream(fExtraFunctions, *fOut);
3665     write_stringstream(body, *fOut);
3666     return fContext.fErrors->errorCount() == 0;
3667 }
3668 
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out,PrettyPrint pp)3669 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out, PrettyPrint pp) {
3670     TRACE_EVENT0("skia.shaders", "SkSL::ToMetal");
3671     SkASSERT(caps != nullptr);
3672 
3673     program.fContext->fErrors->setSource(*program.fSource);
3674     MetalCodeGenerator cg(program.fContext.get(), caps, &program, &out, pp);
3675     bool result = cg.generateCode();
3676     program.fContext->fErrors->setSource(std::string_view());
3677 
3678     return result;
3679 }
3680 
ToMetal(Program & program,const ShaderCaps * caps,OutputStream & out)3681 bool ToMetal(Program& program, const ShaderCaps* caps, OutputStream& out) {
3682 #if defined(SK_DEBUG)
3683     constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kYes;
3684 #else
3685     constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kNo;
3686 #endif
3687     return ToMetal(program, caps, out, defaultPrintOpts);
3688 }
3689 
ToMetal(Program & program,const ShaderCaps * caps,std::string * out)3690 bool ToMetal(Program& program, const ShaderCaps* caps, std::string* out) {
3691     StringStream buffer;
3692     if (!ToMetal(program, caps, buffer)) {
3693         return false;
3694     }
3695     *out = buffer.str();
3696     return true;
3697 }
3698 
3699 }  // namespace SkSL
3700