1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "include/core/SkTypes.h"
9 #include "src/core/SkTHash.h"
10 #include "src/sksl/SkSLAnalysis.h"
11 #include "src/sksl/SkSLContext.h"
12 #include "src/sksl/SkSLErrorReporter.h"
13 #include "src/sksl/analysis/SkSLProgramVisitor.h"
14 #include "src/sksl/ir/SkSLExpression.h"
15 #include "src/sksl/ir/SkSLFunctionCall.h"
16 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
17 #include "src/sksl/ir/SkSLFunctionDefinition.h"
18 #include "src/sksl/ir/SkSLProgram.h"
19 #include "src/sksl/ir/SkSLProgramElement.h"
20
21 #include <cstddef>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 namespace SkSL {
28
CheckProgramStructure(const Program & program)29 bool Analysis::CheckProgramStructure(const Program& program) {
30 const Context& context = *program.fContext;
31
32 static constexpr size_t kProgramStackDepthLimit = 50;
33
34 class ProgramStructureVisitor : public ProgramVisitor {
35 public:
36 ProgramStructureVisitor(const Context& c) : fContext(c) {}
37
38 using ProgramVisitor::visitProgramElement;
39
40 bool visitProgramElement(const ProgramElement& pe) override {
41 if (pe.is<FunctionDefinition>()) {
42 // Check the function map first. We don't need to visit this function if we already
43 // processed it before.
44 const FunctionDeclaration* decl = &pe.as<FunctionDefinition>().declaration();
45 if (FunctionState *funcState = fFunctionMap.find(decl)) {
46 // We already have this function in our map. We don't need to check it again.
47 if (*funcState == FunctionState::kVisiting) {
48 // If the function is present in the map with with the `kVisiting` state,
49 // we're recursively processing it -- in other words, we found a cycle in
50 // the code. Unwind our stack into a string.
51 std::string msg = "\n\t" + decl->description();
52 for (auto unwind = fStack.rbegin(); unwind != fStack.rend(); ++unwind) {
53 msg = "\n\t" + (*unwind)->description() + msg;
54 if (*unwind == decl) {
55 break;
56 }
57 }
58 msg = "potential recursion (function call cycle) not allowed:" + msg;
59 fContext.fErrors->error(pe.fPosition, std::move(msg));
60 *funcState = FunctionState::kVisited;
61 return true;
62 }
63 return false;
64 }
65
66 // If the function-call stack has gotten too deep, stop the analysis.
67 if (fStack.size() >= kProgramStackDepthLimit) {
68 std::string msg = "exceeded max function call depth:";
69 for (auto unwind = fStack.begin(); unwind != fStack.end(); ++unwind) {
70 msg += "\n\t" + (*unwind)->description();
71 }
72 msg += "\n\t" + decl->description();
73 fContext.fErrors->error(pe.fPosition, std::move(msg));
74 fFunctionMap.set(decl, FunctionState::kVisited);
75 return true;
76 }
77
78 fFunctionMap.set(decl, FunctionState::kVisiting);
79 fStack.push_back(decl);
80 bool result = INHERITED::visitProgramElement(pe);
81 fFunctionMap.set(decl, FunctionState::kVisited);
82 fStack.pop_back();
83
84 return result;
85 }
86
87 return INHERITED::visitProgramElement(pe);
88 }
89
90 bool visitExpression(const Expression& expr) override {
91 bool earlyExit = false;
92
93 if (expr.is<FunctionCall>()) {
94 const FunctionCall& call = expr.as<FunctionCall>();
95 const FunctionDeclaration* decl = &call.function();
96 if (decl->definition() && !decl->isIntrinsic()) {
97 earlyExit = this->visitProgramElement(*decl->definition());
98 }
99 }
100
101 return earlyExit || INHERITED::visitExpression(expr);
102 }
103
104 private:
105 using INHERITED = ProgramVisitor;
106
107 enum class FunctionState {
108 kVisiting,
109 kVisited,
110 };
111
112 const Context& fContext;
113 skia_private::THashMap<const FunctionDeclaration*, FunctionState> fFunctionMap;
114 std::vector<const FunctionDeclaration*> fStack;
115 };
116
117 // Process every function in our program.
118 ProgramStructureVisitor visitor{context};
119 for (const std::unique_ptr<ProgramElement>& element : program.fOwnedElements) {
120 if (element->is<FunctionDefinition>()) {
121 // Visit every function--we want to detect static recursion and report it as an error,
122 // even in unreferenced functions.
123 visitor.visitProgramElement(*element);
124 }
125 }
126
127 return true;
128 }
129
130 } // namespace SkSL
131