xref: /aosp_15_r20/external/skia/src/sksl/codegen/SkSLPipelineStageCodeGenerator.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2018 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLPipelineStageCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "src/base/SkEnumBitMask.h"
14 #include "src/core/SkTHash.h"
15 #include "src/sksl/SkSLBuiltinTypes.h"
16 #include "src/sksl/SkSLContext.h"  // IWYU pragma: keep
17 #include "src/sksl/SkSLDefines.h"
18 #include "src/sksl/SkSLIntrinsicList.h"
19 #include "src/sksl/SkSLModule.h"
20 #include "src/sksl/SkSLOperator.h"
21 #include "src/sksl/SkSLProgramSettings.h"
22 #include "src/sksl/SkSLString.h"
23 #include "src/sksl/SkSLStringStream.h"
24 #include "src/sksl/analysis/SkSLSpecialization.h"
25 #include "src/sksl/ir/SkSLBinaryExpression.h"
26 #include "src/sksl/ir/SkSLBlock.h"
27 #include "src/sksl/ir/SkSLChildCall.h"
28 #include "src/sksl/ir/SkSLConstructor.h"
29 #include "src/sksl/ir/SkSLDoStatement.h"
30 #include "src/sksl/ir/SkSLExpression.h"
31 #include "src/sksl/ir/SkSLExpressionStatement.h"
32 #include "src/sksl/ir/SkSLFieldAccess.h"
33 #include "src/sksl/ir/SkSLForStatement.h"
34 #include "src/sksl/ir/SkSLFunctionCall.h"
35 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
36 #include "src/sksl/ir/SkSLFunctionDefinition.h"
37 #include "src/sksl/ir/SkSLIRNode.h"
38 #include "src/sksl/ir/SkSLIfStatement.h"
39 #include "src/sksl/ir/SkSLIndexExpression.h"
40 #include "src/sksl/ir/SkSLModifierFlags.h"
41 #include "src/sksl/ir/SkSLPostfixExpression.h"
42 #include "src/sksl/ir/SkSLPrefixExpression.h"
43 #include "src/sksl/ir/SkSLProgram.h"
44 #include "src/sksl/ir/SkSLProgramElement.h"
45 #include "src/sksl/ir/SkSLReturnStatement.h"
46 #include "src/sksl/ir/SkSLStatement.h"
47 #include "src/sksl/ir/SkSLStructDefinition.h"
48 #include "src/sksl/ir/SkSLSwitchCase.h"
49 #include "src/sksl/ir/SkSLSwitchStatement.h"
50 #include "src/sksl/ir/SkSLSwizzle.h"
51 #include "src/sksl/ir/SkSLTernaryExpression.h"
52 #include "src/sksl/ir/SkSLType.h"
53 #include "src/sksl/ir/SkSLVarDeclarations.h"
54 #include "src/sksl/ir/SkSLVariable.h"
55 #include "src/sksl/ir/SkSLVariableReference.h"
56 #include "src/utils/SkBitSet.h"
57 
58 #include <functional>
59 #include <memory>
60 #include <string_view>
61 #include <utility>
62 
63 using namespace skia_private;
64 
65 namespace SkSL {
66 namespace PipelineStage {
67 
68 class PipelineStageCodeGenerator {
69 public:
PipelineStageCodeGenerator(const Program & program,const char * sampleCoords,const char * inputColor,const char * destColor,Callbacks * callbacks)70     PipelineStageCodeGenerator(const Program& program,
71                                const char* sampleCoords,
72                                const char* inputColor,
73                                const char* destColor,
74                                Callbacks* callbacks)
75             : fProgram(program)
76             , fSampleCoords(sampleCoords)
77             , fInputColor(inputColor)
78             , fDestColor(destColor)
79             , fCallbacks(callbacks) {}
80 
81     void generateCode();
82 
83 private:
84     using Precedence = OperatorPrecedence;
85 
86     void write(std::string_view s);
87     void writeLine(std::string_view s = std::string_view());
88 
89     std::string typeName(const Type& type);
90     void writeType(const Type& type);
91 
92     std::string functionName(const FunctionDeclaration& decl,
93                              Analysis::SpecializationIndex specIndex);
94     void writeFunction(const FunctionDefinition& f);
95     void writeFunctionDeclaration(const FunctionDeclaration& decl);
96 
97     void forEachSpecialization(const FunctionDeclaration& decl, const std::function<void()>& fn);
98 
99     std::string modifierString(ModifierFlags modifiers);
100     std::string functionDeclaration(const FunctionDeclaration& decl);
101 
102     // Handles arrays correctly, eg: `float x[2]`
103     std::string typedVariable(const Type& type, std::string_view name);
104 
105     void writeVarDeclaration(const VarDeclaration& var);
106     void writeGlobalVarDeclaration(const GlobalVarDeclaration& g);
107     void writeStructDefinition(const StructDefinition& s);
108 
109     void writeExpression(const Expression& expr, Precedence parentPrecedence);
110     void writeChildCall(const ChildCall& c);
111     void writeFunctionCall(const FunctionCall& c);
112     void writeAnyConstructor(const AnyConstructor& c, Precedence parentPrecedence);
113     void writeFieldAccess(const FieldAccess& f);
114     void writeSwizzle(const Swizzle& swizzle);
115     void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
116     void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
117     void writeIndexExpression(const IndexExpression& expr);
118     void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
119     void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
120     void writeVariableReference(const VariableReference& ref);
121 
122     void writeStatement(const Statement& s);
123     void writeBlock(const Block& b);
124     void writeIfStatement(const IfStatement& stmt);
125     void writeDoStatement(const DoStatement& d);
126     void writeForStatement(const ForStatement& f);
127     void writeReturnStatement(const ReturnStatement& r);
128     void writeSwitchStatement(const SwitchStatement& s);
129 
130     void writeProgramElementFirstPass(const ProgramElement& e);
131     void writeProgramElementSecondPass(const ProgramElement& e);
132 
133     struct AutoOutputBuffer {
AutoOutputBufferSkSL::PipelineStage::PipelineStageCodeGenerator::AutoOutputBuffer134         AutoOutputBuffer(PipelineStageCodeGenerator* generator) : fGenerator(generator) {
135             fOldBuffer = fGenerator->fBuffer;
136             fGenerator->fBuffer = &fBuffer;
137         }
138 
~AutoOutputBufferSkSL::PipelineStage::PipelineStageCodeGenerator::AutoOutputBuffer139         ~AutoOutputBuffer() {
140             fGenerator->fBuffer = fOldBuffer;
141         }
142 
143         PipelineStageCodeGenerator* fGenerator;
144         StringStream*               fOldBuffer;
145         StringStream                fBuffer;
146     };
147 
148     const Program& fProgram;
149     const char*    fSampleCoords;
150     const char*    fInputColor;
151     const char*    fDestColor;
152     Callbacks*     fCallbacks;
153 
154     Analysis::SpecializationInfo fSpecializationInfo;
155     Analysis::SpecializationIndex fActiveSpecializationIndex = Analysis::kUnspecialized;
156     const Analysis::SpecializedParameters* fActiveSpecialization = nullptr;
157 
158     THashMap<const Variable*, std::string>                                      fVariableNames;
159     THashMap<const Type*, std::string>                                          fStructNames;
160     THashMap<Analysis::SpecializedFunctionKey, std::string, Analysis::SpecializedFunctionKey::Hash>
161             fFunctionNames;
162 
163     StringStream*              fBuffer = nullptr;
164     bool                       fCastReturnsToHalf = false;
165     const FunctionDeclaration* fCurrentFunction = nullptr;
166 };
167 
168 
write(std::string_view s)169 void PipelineStageCodeGenerator::write(std::string_view s) {
170     fBuffer->write(s.data(), s.length());
171 }
172 
writeLine(std::string_view s)173 void PipelineStageCodeGenerator::writeLine(std::string_view s) {
174     fBuffer->write(s.data(), s.length());
175     fBuffer->writeText("\n");
176 }
177 
writeChildCall(const ChildCall & c)178 void PipelineStageCodeGenerator::writeChildCall(const ChildCall& c) {
179     const Variable* child = &c.child();
180 
181     if (fActiveSpecialization) {
182         const Expression** specializedChild = fActiveSpecialization->find(child);
183         if (specializedChild) {
184             SkASSERT(*specializedChild);
185             child = (*specializedChild)->as<VariableReference>().variable();
186         }
187     }
188 
189     const ExpressionArray& arguments = c.arguments();
190     SkASSERT(!arguments.empty());
191     int index = 0;
192     bool found = false;
193     for (const ProgramElement* p : fProgram.elements()) {
194         if (p->is<GlobalVarDeclaration>()) {
195             const GlobalVarDeclaration& global = p->as<GlobalVarDeclaration>();
196             const VarDeclaration& decl = global.varDeclaration();
197             if (decl.var() == child) {
198                 found = true;
199             } else if (decl.var()->type().isEffectChild()) {
200                 ++index;
201             }
202         }
203         if (found) {
204             break;
205         }
206     }
207     SkASSERT(found);
208 
209     // Shaders require a coordinate argument. Color filters require a color argument.
210     // Blenders require two color arguments.
211     std::string sampleOutput;
212     {
213         AutoOutputBuffer exprBuffer(this);
214         this->writeExpression(*arguments[0], Precedence::kSequence);
215 
216         switch (c.child().type().typeKind()) {
217             case Type::TypeKind::kShader: {
218                 SkASSERT(arguments.size() == 1);
219                 SkASSERT(arguments[0]->type().matches(*fProgram.fContext->fTypes.fFloat2));
220                 sampleOutput = fCallbacks->sampleShader(index, exprBuffer.fBuffer.str());
221                 break;
222             }
223             case Type::TypeKind::kColorFilter: {
224                 SkASSERT(arguments.size() == 1);
225                 SkASSERT(arguments[0]->type().matches(*fProgram.fContext->fTypes.fHalf4) ||
226                          arguments[0]->type().matches(*fProgram.fContext->fTypes.fFloat4));
227                 sampleOutput = fCallbacks->sampleColorFilter(index, exprBuffer.fBuffer.str());
228                 break;
229             }
230             case Type::TypeKind::kBlender: {
231                 SkASSERT(arguments.size() == 2);
232                 SkASSERT(arguments[0]->type().matches(*fProgram.fContext->fTypes.fHalf4) ||
233                          arguments[0]->type().matches(*fProgram.fContext->fTypes.fFloat4));
234                 SkASSERT(arguments[1]->type().matches(*fProgram.fContext->fTypes.fHalf4) ||
235                          arguments[1]->type().matches(*fProgram.fContext->fTypes.fFloat4));
236 
237                 AutoOutputBuffer exprBuffer2(this);
238                 this->writeExpression(*arguments[1], Precedence::kSequence);
239 
240                 sampleOutput = fCallbacks->sampleBlender(index, exprBuffer.fBuffer.str(),
241                                                                 exprBuffer2.fBuffer.str());
242                 break;
243             }
244             default: {
245                 SkDEBUGFAILF("cannot sample from type '%s'", child->type().description().c_str());
246             }
247         }
248     }
249     this->write(sampleOutput);
250 }
251 
writeFunctionCall(const FunctionCall & c)252 void PipelineStageCodeGenerator::writeFunctionCall(const FunctionCall& c) {
253     const FunctionDeclaration& function = c.function();
254 
255     if (function.intrinsicKind() == IntrinsicKind::k_toLinearSrgb_IntrinsicKind ||
256         function.intrinsicKind() == IntrinsicKind::k_fromLinearSrgb_IntrinsicKind) {
257         SkASSERT(c.arguments().size() == 1);
258         std::string colorArg;
259         {
260             AutoOutputBuffer exprBuffer(this);
261             this->writeExpression(*c.arguments()[0], Precedence::kSequence);
262             colorArg = exprBuffer.fBuffer.str();
263         }
264 
265         switch (function.intrinsicKind()) {
266             case IntrinsicKind::k_toLinearSrgb_IntrinsicKind:
267                 this->write(fCallbacks->toLinearSrgb(std::move(colorArg)));
268                 break;
269             case IntrinsicKind::k_fromLinearSrgb_IntrinsicKind:
270                 this->write(fCallbacks->fromLinearSrgb(std::move(colorArg)));
271                 break;
272             default:
273                 SkUNREACHABLE;
274         }
275 
276         return;
277     }
278 
279     // Look up the specialization data, if any, needed for this function call.
280     Analysis::SpecializationIndex callIndex = Analysis::FindSpecializationIndexForCall(
281             c, fSpecializationInfo, fActiveSpecializationIndex);
282     SkBitSet specializedParams =
283             Analysis::FindSpecializedParametersForFunction(function, fSpecializationInfo);
284 
285     this->write(this->functionName(function, callIndex));
286     this->write("(");
287     auto separator = SkSL::String::Separator();
288     for (int argIdx = 0; argIdx < c.arguments().size(); ++argIdx) {
289         // If this parameter is specialized, it is baked into the destination function and should
290         // not be passed along as an argument.
291         if (specializedParams.test(argIdx)) {
292             continue;
293         }
294 
295         // This is a regular argument and must be passed normally.
296         this->write(separator());
297         this->writeExpression(*c.arguments()[argIdx], Precedence::kSequence);
298     }
299     this->write(")");
300 }
301 
writeVariableReference(const VariableReference & ref)302 void PipelineStageCodeGenerator::writeVariableReference(const VariableReference& ref) {
303     const Variable* var = ref.variable();
304 
305     if (fCurrentFunction && var == fCurrentFunction->getMainCoordsParameter()) {
306         this->write(fSampleCoords);
307         return;
308     }
309     if (fCurrentFunction && var == fCurrentFunction->getMainInputColorParameter()) {
310         this->write(fInputColor);
311         return;
312     }
313     if (fCurrentFunction && var == fCurrentFunction->getMainDestColorParameter()) {
314         this->write(fDestColor);
315         return;
316     }
317 
318     std::string* name = fVariableNames.find(var);
319     this->write(name ? *name : var->name());
320 }
321 
writeIfStatement(const IfStatement & stmt)322 void PipelineStageCodeGenerator::writeIfStatement(const IfStatement& stmt) {
323     this->write("if (");
324     this->writeExpression(*stmt.test(), Precedence::kExpression);
325     this->write(") ");
326     this->writeStatement(*stmt.ifTrue());
327     if (stmt.ifFalse()) {
328         this->write(" else ");
329         this->writeStatement(*stmt.ifFalse());
330     }
331 }
332 
writeReturnStatement(const ReturnStatement & r)333 void PipelineStageCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
334     this->write("return");
335     if (r.expression()) {
336         this->write(" ");
337         if (fCastReturnsToHalf) {
338             this->write("half4(");
339         }
340         this->writeExpression(*r.expression(), Precedence::kExpression);
341         if (fCastReturnsToHalf) {
342             this->write(")");
343         }
344     }
345     this->write(";");
346 }
347 
writeSwitchStatement(const SwitchStatement & s)348 void PipelineStageCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
349     this->write("switch (");
350     this->writeExpression(*s.value(), Precedence::kExpression);
351     this->writeLine(") {");
352     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
353         const SwitchCase& c = stmt->as<SwitchCase>();
354         if (c.isDefault()) {
355             this->writeLine("default:");
356         } else {
357             this->write("case ");
358             this->write(std::to_string(c.value()));
359             this->writeLine(":");
360         }
361         if (!c.statement()->isEmpty()) {
362             this->writeStatement(*c.statement());
363             this->writeLine();
364         }
365     }
366     this->writeLine();
367     this->write("}");
368 }
369 
functionName(const FunctionDeclaration & decl,Analysis::SpecializationIndex specIndex)370 std::string PipelineStageCodeGenerator::functionName(const FunctionDeclaration& decl,
371                                                      Analysis::SpecializationIndex specIndex) {
372     if (decl.isMain()) {
373         return std::string(fCallbacks->getMainName());
374     }
375 
376     // Intrinsics and functions from sksl_shared do not use name mangling.
377     if (decl.isIntrinsic() || decl.moduleType() == ModuleType::sksl_shared) {
378         return std::string(decl.name());
379     }
380 
381     Analysis::SpecializedFunctionKey key{&decl, specIndex};
382     if (std::string* name = fFunctionNames.find(key)) {
383         return *name;
384     }
385 
386     std::string specializedName = std::string(decl.name());
387 
388     // For specialized functions, tack on `_param1_param2` to the function name.
389     Analysis::GetParameterMappingsForFunction(decl, fSpecializationInfo, specIndex,
390                                               [&](int, const Variable*, const Expression* expr) {
391                                                   specializedName += '_';
392                                                   specializedName += expr->description();
393                                               });
394 
395     std::string mangledName = fCallbacks->getMangledName(specializedName.c_str());
396     fFunctionNames.set(key, mangledName);
397     return mangledName;
398 }
399 
writeFunction(const FunctionDefinition & f)400 void PipelineStageCodeGenerator::writeFunction(const FunctionDefinition& f) {
401     // Don't re-emit functions from sksl_shared. (Functions from the `sksl_rt_shader` module won't
402     // be visible once the shader is converted into a pipeline stage, so we do emit those.)
403     if (f.declaration().moduleType() == ModuleType::sksl_shared) {
404         return;
405     }
406 
407     SkASSERT(!fCurrentFunction);
408     fCurrentFunction = &f.declaration();
409 
410     // We allow public SkSL's main() to return half4 _or_ float4 (i.e. vec4). When we emit our code
411     // in the processor, the surrounding code is going to expect half4, so we explicitly cast any
412     // returns (from main) to half4. This is only strictly necessary if the return type is float4,
413     // but we inject it unconditionally as a defensive measure, since it is free and harmless.
414     const FunctionDeclaration& decl = f.declaration();
415     if (decl.isMain() && !ProgramConfig::IsMesh(fProgram.fConfig->fKind)) {
416         fCastReturnsToHalf = true;
417     }
418 
419     this->forEachSpecialization(*fCurrentFunction, [&] {
420         // Assemble the function body into a separate output stream.
421         AutoOutputBuffer body(this);
422         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
423             this->writeStatement(*stmt);
424             this->writeLine();
425         }
426 
427         // Emit the function.
428         fCallbacks->defineFunction(this->functionDeclaration(decl).c_str(),
429                                    body.fBuffer.str().c_str(),
430                                    decl.isMain());
431     });
432 
433     if (decl.isMain()) {
434         fCastReturnsToHalf = false;
435     }
436 
437     fCurrentFunction = nullptr;
438 }
439 
functionDeclaration(const FunctionDeclaration & decl)440 std::string PipelineStageCodeGenerator::functionDeclaration(const FunctionDeclaration& decl) {
441     // This is similar to decl.description(), but substitutes a mangled name, and handles modifiers
442     // on the function (e.g. `inline`) and its parameters (e.g. `inout`).
443     std::string declString =
444             String::printf("%s%s%s %s(",
445                            decl.modifierFlags().isInline() ? "inline " : "",
446                            decl.modifierFlags().isNoInline() ? "noinline " : "",
447                            this->typeName(decl.returnType()).c_str(),
448                            this->functionName(decl, fActiveSpecializationIndex).c_str());
449 
450     auto separator = SkSL::String::Separator();
451     for (const Variable* p : decl.parameters()) {
452         // Skip past parameters that we are specializing.
453         bool paramIsSpecialized = fActiveSpecialization && fActiveSpecialization->find(p);
454         if (!paramIsSpecialized) {
455             declString.append(separator());
456             declString.append(this->modifierString(p->modifierFlags()));
457             declString.append(this->typedVariable(p->type(), p->name()).c_str());
458         }
459     }
460 
461     return declString + ")";
462 }
463 
writeFunctionDeclaration(const FunctionDeclaration & decl)464 void PipelineStageCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& decl) {
465     if (!decl.isMain() && decl.moduleType() != ModuleType::sksl_shared) {
466         this->forEachSpecialization(decl, [&] {
467             std::string prototype = this->functionDeclaration(decl) + ';';
468             fCallbacks->declareFunction(prototype.c_str());
469         });
470     }
471 }
472 
forEachSpecialization(const FunctionDeclaration & decl,const std::function<void ()> & fn)473 void PipelineStageCodeGenerator::forEachSpecialization(const FunctionDeclaration& decl,
474                                                        const std::function<void()>& fn) {
475     // Save off the current specialization.
476     Analysis::SpecializationIndex prevSpecializationIndex = fActiveSpecializationIndex;
477     const Analysis::SpecializedParameters* prevSpecialization = fActiveSpecialization;
478 
479     if (const Analysis::Specializations* specializations =
480                 fSpecializationInfo.fSpecializationMap.find(&decl)) {
481         // Invoke the callback for each specialization.
482         for (fActiveSpecializationIndex = 0;
483              fActiveSpecializationIndex < specializations->size();
484              ++fActiveSpecializationIndex) {
485             fActiveSpecialization = &specializations->at(fActiveSpecializationIndex);
486             fn();
487         }
488     } else {
489         // This function isn't specialized, so emit its declaration normally.
490         fActiveSpecializationIndex = Analysis::kUnspecialized;
491         fActiveSpecialization = nullptr;
492         fn();
493     }
494 
495     // Restore the previous specialization.
496     fActiveSpecializationIndex = prevSpecializationIndex;
497     fActiveSpecialization = prevSpecialization;
498 }
499 
writeGlobalVarDeclaration(const GlobalVarDeclaration & g)500 void PipelineStageCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& g) {
501     const VarDeclaration& decl = g.varDeclaration();
502     const Variable& var = *decl.var();
503 
504     if (var.isBuiltin() || var.type().isOpaque()) {
505         // Don't re-declare these. (eg, sk_FragCoord, or fragmentProcessor children)
506     } else if (var.modifierFlags().isUniform()) {
507         std::string uniformName = fCallbacks->declareUniform(&decl);
508         fVariableNames.set(&var, std::move(uniformName));
509     } else {
510         std::string mangledName = fCallbacks->getMangledName(std::string(var.name()).c_str());
511         std::string declaration = this->modifierString(var.modifierFlags()) +
512                                   this->typedVariable(var.type(), mangledName);
513         if (decl.value()) {
514             AutoOutputBuffer outputToBuffer(this);
515             this->writeExpression(*decl.value(), Precedence::kExpression);
516             declaration += " = ";
517             declaration += outputToBuffer.fBuffer.str();
518         }
519         declaration += ";\n";
520         fCallbacks->declareGlobal(declaration.c_str());
521         fVariableNames.set(&var, std::move(mangledName));
522     }
523 }
524 
writeStructDefinition(const StructDefinition & s)525 void PipelineStageCodeGenerator::writeStructDefinition(const StructDefinition& s) {
526     const Type& type = s.type();
527     std::string mangledName = fCallbacks->getMangledName(type.displayName().c_str());
528     std::string definition = "struct " + mangledName + " {\n";
529     for (const auto& f : type.fields()) {
530         definition += this->typedVariable(*f.fType, f.fName) + ";\n";
531     }
532     definition += "};\n";
533     fStructNames.set(&type, std::move(mangledName));
534     fCallbacks->defineStruct(definition.c_str());
535 }
536 
writeProgramElementFirstPass(const ProgramElement & e)537 void PipelineStageCodeGenerator::writeProgramElementFirstPass(const ProgramElement& e) {
538     switch (e.kind()) {
539         case ProgramElement::Kind::kGlobalVar:
540             this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
541             break;
542         case ProgramElement::Kind::kFunction:
543             this->writeFunctionDeclaration(e.as<FunctionDefinition>().declaration());
544             break;
545         case ProgramElement::Kind::kFunctionPrototype:
546             // Skip this; we're already emitting prototypes for every FunctionDefinition.
547             // (See case kFunction, directly above.)
548             break;
549         case ProgramElement::Kind::kStructDefinition:
550             this->writeStructDefinition(e.as<StructDefinition>());
551             break;
552 
553         case ProgramElement::Kind::kExtension:
554         case ProgramElement::Kind::kInterfaceBlock:
555         case ProgramElement::Kind::kModifiers:
556         default:
557             SkDEBUGFAILF("unsupported program element %s\n", e.description().c_str());
558             break;
559     }
560 }
561 
writeProgramElementSecondPass(const ProgramElement & e)562 void PipelineStageCodeGenerator::writeProgramElementSecondPass(const ProgramElement& e) {
563     if (e.is<FunctionDefinition>()) {
564         this->writeFunction(e.as<FunctionDefinition>());
565     }
566 }
567 
typeName(const Type & raw)568 std::string PipelineStageCodeGenerator::typeName(const Type& raw) {
569     const Type& type = raw.resolve().scalarTypeForLiteral();
570     if (type.isArray()) {
571         // This is necessary so that name mangling on arrays-of-structs works properly.
572         std::string arrayName = this->typeName(type.componentType());
573         arrayName.push_back('[');
574         arrayName += std::to_string(type.columns());
575         arrayName.push_back(']');
576         return arrayName;
577     }
578 
579     std::string* name = fStructNames.find(&type);
580     return name ? *name : std::string(type.name());
581 }
582 
writeType(const Type & type)583 void PipelineStageCodeGenerator::writeType(const Type& type) {
584     this->write(this->typeName(type));
585 }
586 
writeExpression(const Expression & expr,Precedence parentPrecedence)587 void PipelineStageCodeGenerator::writeExpression(const Expression& expr,
588                                                  Precedence parentPrecedence) {
589     switch (expr.kind()) {
590         case Expression::Kind::kBinary:
591             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
592             break;
593         case Expression::Kind::kLiteral:
594         case Expression::Kind::kSetting:
595             this->write(expr.description());
596             break;
597         case Expression::Kind::kChildCall:
598             this->writeChildCall(expr.as<ChildCall>());
599             break;
600         case Expression::Kind::kConstructorArray:
601         case Expression::Kind::kConstructorArrayCast:
602         case Expression::Kind::kConstructorCompound:
603         case Expression::Kind::kConstructorCompoundCast:
604         case Expression::Kind::kConstructorDiagonalMatrix:
605         case Expression::Kind::kConstructorMatrixResize:
606         case Expression::Kind::kConstructorScalarCast:
607         case Expression::Kind::kConstructorSplat:
608         case Expression::Kind::kConstructorStruct:
609             this->writeAnyConstructor(expr.asAnyConstructor(), parentPrecedence);
610             break;
611         case Expression::Kind::kEmpty:
612             this->write("false");
613             break;
614         case Expression::Kind::kFieldAccess:
615             this->writeFieldAccess(expr.as<FieldAccess>());
616             break;
617         case Expression::Kind::kFunctionCall:
618             this->writeFunctionCall(expr.as<FunctionCall>());
619             break;
620         case Expression::Kind::kPrefix:
621             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
622             break;
623         case Expression::Kind::kPostfix:
624             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
625             break;
626         case Expression::Kind::kSwizzle:
627             this->writeSwizzle(expr.as<Swizzle>());
628             break;
629         case Expression::Kind::kVariableReference:
630             this->writeVariableReference(expr.as<VariableReference>());
631             break;
632         case Expression::Kind::kTernary:
633             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
634             break;
635         case Expression::Kind::kIndex:
636             this->writeIndexExpression(expr.as<IndexExpression>());
637             break;
638         default:
639             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
640             break;
641     }
642 }
643 
writeAnyConstructor(const AnyConstructor & c,Precedence parentPrecedence)644 void PipelineStageCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
645                                                      Precedence parentPrecedence) {
646     this->writeType(c.type());
647     this->write("(");
648     auto separator = SkSL::String::Separator();
649     for (const auto& arg : c.argumentSpan()) {
650         this->write(separator());
651         this->writeExpression(*arg, Precedence::kSequence);
652     }
653     this->write(")");
654 }
655 
writeIndexExpression(const IndexExpression & expr)656 void PipelineStageCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
657     this->writeExpression(*expr.base(), Precedence::kPostfix);
658     this->write("[");
659     this->writeExpression(*expr.index(), Precedence::kExpression);
660     this->write("]");
661 }
662 
writeFieldAccess(const FieldAccess & f)663 void PipelineStageCodeGenerator::writeFieldAccess(const FieldAccess& f) {
664     if (f.ownerKind() == FieldAccess::OwnerKind::kDefault) {
665         this->writeExpression(*f.base(), Precedence::kPostfix);
666         this->write(".");
667     }
668     const Type& baseType = f.base()->type();
669     this->write(baseType.fields()[f.fieldIndex()].fName);
670 }
671 
writeSwizzle(const Swizzle & swizzle)672 void PipelineStageCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
673     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
674     this->write(".");
675     this->write(Swizzle::MaskString(swizzle.components()));
676 }
677 
writeBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)678 void PipelineStageCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
679                                                        Precedence parentPrecedence) {
680     const Expression& left = *b.left();
681     const Expression& right = *b.right();
682     Operator op = b.getOperator();
683 
684     Precedence precedence = op.getBinaryPrecedence();
685     if (precedence >= parentPrecedence) {
686         this->write("(");
687     }
688     this->writeExpression(left, precedence);
689     this->write(op.operatorName());
690     this->writeExpression(right, precedence);
691     if (precedence >= parentPrecedence) {
692         this->write(")");
693     }
694 }
695 
writeTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)696 void PipelineStageCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
697                                                         Precedence parentPrecedence) {
698     if (Precedence::kTernary >= parentPrecedence) {
699         this->write("(");
700     }
701     this->writeExpression(*t.test(), Precedence::kTernary);
702     this->write(" ? ");
703     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
704     this->write(" : ");
705     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
706     if (Precedence::kTernary >= parentPrecedence) {
707         this->write(")");
708     }
709 }
710 
writePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)711 void PipelineStageCodeGenerator::writePrefixExpression(const PrefixExpression& p,
712                                                        Precedence parentPrecedence) {
713     if (Precedence::kPrefix >= parentPrecedence) {
714         this->write("(");
715     }
716     this->write(p.getOperator().tightOperatorName());
717     this->writeExpression(*p.operand(), Precedence::kPrefix);
718     if (Precedence::kPrefix >= parentPrecedence) {
719         this->write(")");
720     }
721 }
722 
writePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)723 void PipelineStageCodeGenerator::writePostfixExpression(const PostfixExpression& p,
724                                                         Precedence parentPrecedence) {
725     if (Precedence::kPostfix >= parentPrecedence) {
726         this->write("(");
727     }
728     this->writeExpression(*p.operand(), Precedence::kPostfix);
729     this->write(p.getOperator().tightOperatorName());
730     if (Precedence::kPostfix >= parentPrecedence) {
731         this->write(")");
732     }
733 }
734 
modifierString(ModifierFlags flags)735 std::string PipelineStageCodeGenerator::modifierString(ModifierFlags flags) {
736     std::string result;
737     if (flags.isConst()) {
738         result.append("const ");
739     }
740     if ((flags & ModifierFlag::kIn) && (flags & ModifierFlag::kOut)) {
741         result.append("inout ");
742     } else if (flags & ModifierFlag::kIn) {
743         result.append("in ");
744     } else if (flags & ModifierFlag::kOut) {
745         result.append("out ");
746     }
747 
748     return result;
749 }
750 
typedVariable(const Type & type,std::string_view name)751 std::string PipelineStageCodeGenerator::typedVariable(const Type& type, std::string_view name) {
752     const Type& baseType = type.isArray() ? type.componentType() : type;
753 
754     std::string decl = this->typeName(baseType) + " " + std::string(name);
755     if (type.isArray()) {
756         decl += "[" + std::to_string(type.columns()) + "]";
757     }
758     return decl;
759 }
760 
writeVarDeclaration(const VarDeclaration & var)761 void PipelineStageCodeGenerator::writeVarDeclaration(const VarDeclaration& var) {
762     this->write(this->modifierString(var.var()->modifierFlags()));
763     this->write(this->typedVariable(var.var()->type(), var.var()->name()));
764     if (var.value()) {
765         this->write(" = ");
766         this->writeExpression(*var.value(), Precedence::kExpression);
767     }
768     this->write(";");
769 }
770 
writeStatement(const Statement & s)771 void PipelineStageCodeGenerator::writeStatement(const Statement& s) {
772     switch (s.kind()) {
773         case Statement::Kind::kBlock:
774             this->writeBlock(s.as<Block>());
775             break;
776         case Statement::Kind::kBreak:
777             this->write("break;");
778             break;
779         case Statement::Kind::kContinue:
780             this->write("continue;");
781             break;
782         case Statement::Kind::kExpression:
783             this->writeExpression(*s.as<ExpressionStatement>().expression(),
784                                   Precedence::kStatement);
785             this->write(";");
786             break;
787         case Statement::Kind::kDo:
788             this->writeDoStatement(s.as<DoStatement>());
789             break;
790         case Statement::Kind::kFor:
791             this->writeForStatement(s.as<ForStatement>());
792             break;
793         case Statement::Kind::kIf:
794             this->writeIfStatement(s.as<IfStatement>());
795             break;
796         case Statement::Kind::kReturn:
797             this->writeReturnStatement(s.as<ReturnStatement>());
798             break;
799         case Statement::Kind::kSwitch:
800             this->writeSwitchStatement(s.as<SwitchStatement>());
801             break;
802         case Statement::Kind::kVarDeclaration:
803             this->writeVarDeclaration(s.as<VarDeclaration>());
804             break;
805         case Statement::Kind::kDiscard:
806             SkDEBUGFAIL("Unsupported control flow");
807             break;
808         case Statement::Kind::kNop:
809             this->write(";");
810             break;
811         default:
812             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
813             break;
814     }
815 }
816 
writeBlock(const Block & b)817 void PipelineStageCodeGenerator::writeBlock(const Block& b) {
818     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
819     // something here to make the code valid).
820     bool isScope = b.isScope() || b.isEmpty();
821     if (isScope) {
822         this->writeLine("{");
823     }
824     for (const std::unique_ptr<Statement>& stmt : b.children()) {
825         if (!stmt->isEmpty()) {
826             this->writeStatement(*stmt);
827             this->writeLine();
828         }
829     }
830     if (isScope) {
831         this->write("}");
832     }
833 }
834 
writeDoStatement(const DoStatement & d)835 void PipelineStageCodeGenerator::writeDoStatement(const DoStatement& d) {
836     this->write("do ");
837     this->writeStatement(*d.statement());
838     this->write(" while (");
839     this->writeExpression(*d.test(), Precedence::kExpression);
840     this->write(");");
841 }
842 
writeForStatement(const ForStatement & f)843 void PipelineStageCodeGenerator::writeForStatement(const ForStatement& f) {
844     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
845     if (!f.initializer() && f.test() && !f.next()) {
846         this->write("while (");
847         this->writeExpression(*f.test(), Precedence::kExpression);
848         this->write(") ");
849         this->writeStatement(*f.statement());
850         return;
851     }
852 
853     this->write("for (");
854     if (f.initializer() && !f.initializer()->isEmpty()) {
855         this->writeStatement(*f.initializer());
856     } else {
857         this->write("; ");
858     }
859     if (f.test()) {
860         this->writeExpression(*f.test(), Precedence::kExpression);
861     }
862     this->write("; ");
863     if (f.next()) {
864         this->writeExpression(*f.next(), Precedence::kExpression);
865     }
866     this->write(") ");
867     this->writeStatement(*f.statement());
868 }
869 
generateCode()870 void PipelineStageCodeGenerator::generateCode() {
871     // Search for functions which require specialization due to passing child effects as parameters.
872     Analysis::FindFunctionsToSpecialize(fProgram, &fSpecializationInfo, [](const Variable& param) {
873         return param.type().isEffectChild();
874     });
875 
876     // Write all the program elements except for functions; prototype all the functions.
877     for (const ProgramElement* e : fProgram.elements()) {
878         this->writeProgramElementFirstPass(*e);
879     }
880 
881     // We always place FunctionDefinition elements last, because the inliner likes to move function
882     // bodies around. After inlining, code can inadvertently move upwards, above ProgramElements
883     // that the code relies on.
884     for (const ProgramElement* e : fProgram.elements()) {
885         this->writeProgramElementSecondPass(*e);
886     }
887 }
888 
ConvertProgram(const Program & program,const char * sampleCoords,const char * inputColor,const char * destColor,Callbacks * callbacks)889 void ConvertProgram(const Program& program,
890                     const char* sampleCoords,
891                     const char* inputColor,
892                     const char* destColor,
893                     Callbacks* callbacks) {
894     PipelineStageCodeGenerator generator(program, sampleCoords, inputColor, destColor, callbacks);
895     generator.generateCode();
896 }
897 
898 }  // namespace PipelineStage
899 }  // namespace SkSL
900