1 /* 2 * Copyright 2024 Google LLC 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #include "src/sksl/SkSLModule.h" 9 #include "src/sksl/ir/SkSLConstructorSplat.h" 10 #include "src/sksl/ir/SkSLExpression.h" 11 #include "src/sksl/ir/SkSLFunctionDefinition.h" 12 #include "src/sksl/ir/SkSLLiteral.h" 13 #include "src/sksl/ir/SkSLProgramElement.h" 14 #include "src/sksl/ir/SkSLSwizzle.h" 15 #include "src/sksl/ir/SkSLType.h" 16 #include "src/sksl/transform/SkSLProgramWriter.h" 17 #include "src/sksl/transform/SkSLTransform.h" 18 19 #include <memory> 20 #include <utility> 21 #include <vector> 22 23 namespace SkSL { 24 25 class Context; 26 ReplaceSplatCastsWithSwizzles(const Context & context,Module & module)27void Transform::ReplaceSplatCastsWithSwizzles(const Context& context, Module& module) { 28 class SwizzleWriter : public ProgramWriter { 29 public: 30 SwizzleWriter(const Context& ctx) : fContext(ctx) {} 31 32 bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override { 33 if (INHERITED::visitExpressionPtr(expr)) { 34 return true; 35 } 36 if (expr->is<ConstructorSplat>()) { 37 // If the argument is a literal, only allow floats. The swizzled-literal syntax only 38 // works properly for floats. 39 std::unique_ptr<Expression>& arg = expr->as<ConstructorSplat>().argument(); 40 if (!arg->is<Literal>() || (arg->type().isFloat() && arg->type().highPrecision())) { 41 // Synthesize a splat like `.xxxx`, matching the column count of the splat. 42 ComponentArray components; 43 components.push_back_n(expr->type().columns(), SwizzleComponent::X); 44 45 // Replace the splat expression with the swizzle. 46 expr = Swizzle::MakeExact(fContext, expr->position(), std::move(arg), 47 std::move(components)); 48 } 49 } 50 return false; 51 } 52 53 const Context& fContext; 54 55 using INHERITED = ProgramWriter; 56 }; 57 58 for (std::unique_ptr<ProgramElement>& pe : module.fElements) { 59 if (pe->is<FunctionDefinition>()) { 60 SwizzleWriter writer{context}; 61 writer.visitStatementPtr(pe->as<FunctionDefinition>().body()); 62 } 63 } 64 } 65 66 } // namespace SkSL 67