xref: /aosp_15_r20/external/skia/src/sksl/transform/SkSLReplaceSplatCastsWithSwizzles.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2024 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLModule.h"
9 #include "src/sksl/ir/SkSLConstructorSplat.h"
10 #include "src/sksl/ir/SkSLExpression.h"
11 #include "src/sksl/ir/SkSLFunctionDefinition.h"
12 #include "src/sksl/ir/SkSLLiteral.h"
13 #include "src/sksl/ir/SkSLProgramElement.h"
14 #include "src/sksl/ir/SkSLSwizzle.h"
15 #include "src/sksl/ir/SkSLType.h"
16 #include "src/sksl/transform/SkSLProgramWriter.h"
17 #include "src/sksl/transform/SkSLTransform.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 namespace SkSL {
24 
25 class Context;
26 
ReplaceSplatCastsWithSwizzles(const Context & context,Module & module)27 void Transform::ReplaceSplatCastsWithSwizzles(const Context& context, Module& module) {
28     class SwizzleWriter : public ProgramWriter {
29     public:
30         SwizzleWriter(const Context& ctx) : fContext(ctx) {}
31 
32         bool visitExpressionPtr(std::unique_ptr<Expression>& expr) override {
33             if (INHERITED::visitExpressionPtr(expr)) {
34                 return true;
35             }
36             if (expr->is<ConstructorSplat>()) {
37                 // If the argument is a literal, only allow floats. The swizzled-literal syntax only
38                 // works properly for floats.
39                 std::unique_ptr<Expression>& arg = expr->as<ConstructorSplat>().argument();
40                 if (!arg->is<Literal>() || (arg->type().isFloat() && arg->type().highPrecision())) {
41                     // Synthesize a splat like `.xxxx`, matching the column count of the splat.
42                     ComponentArray components;
43                     components.push_back_n(expr->type().columns(), SwizzleComponent::X);
44 
45                     // Replace the splat expression with the swizzle.
46                     expr = Swizzle::MakeExact(fContext, expr->position(), std::move(arg),
47                                               std::move(components));
48                 }
49             }
50             return false;
51         }
52 
53         const Context& fContext;
54 
55         using INHERITED = ProgramWriter;
56     };
57 
58     for (std::unique_ptr<ProgramElement>& pe : module.fElements) {
59         if (pe->is<FunctionDefinition>()) {
60             SwizzleWriter writer{context};
61             writer.visitStatementPtr(pe->as<FunctionDefinition>().body());
62         }
63     }
64 }
65 
66 }  // namespace SkSL
67