1 #pragma once
2
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
11
12 namespace torch::jit {
13
14 // Trees are used to represent all forms of TC IR, pre- and post-typechecking.
15 // Rather than have a full class hierarchy for all TC statements, trees are a
16 // slight variation of Lisp s-expressions. For instance, the expression a*b+1
17 // is represented as:
18 // (+ (* (ident a) (ident b)) (const 1))
19 // Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
20 // define stringValue(). Everything else is a Compound object, which has a
21 // 'kind' that is a token from lexer.h's TokenKind enum. Single-character
22 // operators like '+' are represented using the character itself (so, add.kind()
23 // would be '+'). Each Compound object also contains a list of subtrees and is
24 // associated with a SourceRange for error reporting.
25 // Memory management of trees is done using intrusive_ptr.
26
struct Tree;
// Owning, reference-counted handle to a tree node. The reference count is
// stored in the node itself (Tree derives from c10::intrusive_ptr_target).
using TreeRef = c10::intrusive_ptr<Tree>;
// Ordered list of child trees. The `4` is SmallVector's inline capacity, so
// small child lists avoid a separate heap allocation.
using TreeList = at::SmallVector<TreeRef, 4>;
30
31 struct Tree : c10::intrusive_ptr_target {
TreeTree32 Tree(int kind_) : kind_(kind_) {}
kindTree33 int kind() const {
34 return kind_;
35 }
isAtomTree36 virtual bool isAtom() const {
37 return true;
38 }
rangeTree39 virtual const SourceRange& range() const {
40 throw std::runtime_error("is an Atom");
41 }
stringValueTree42 virtual const std::string& stringValue() const {
43 throw std::runtime_error("stringValue can only be called on TK_STRING");
44 }
treesTree45 virtual const TreeList& trees() const {
46 static const TreeList empty_trees = {};
47 return empty_trees;
48 }
treeTree49 const TreeRef& tree(size_t i) const {
50 return trees().at(i);
51 }
mapTree52 virtual TreeRef map(const std::function<TreeRef(TreeRef)>& fn) {
53 (void)fn;
54 c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
55 // from a raw `this` pointer
56 // so we need to bump the refcount
57 // to account for this ownership
58 return TreeRef::reclaim(this);
59 }
60 template <typename... Args>
matchTree61 void match(int k, Args&... args) const {
62 matchD(k, "unknown", 0, args...);
63 }
64 template <typename... Args>
matchDTree65 void matchD(int k, const char* filename, int lineno, Args&... args) const {
66 std::initializer_list<TreeRef*> vars = {args...};
67 matchNumSubtreesD(k, filename, lineno, vars.size(), true);
68 size_t i = 0;
69 for (TreeRef* v : vars) {
70 *v = trees()[i++];
71 }
72 }
matchNumSubtreesTree73 void matchNumSubtrees(int k, size_t expected_subtrees) {
74 return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false);
75 }
matchNumSubtreesDTree76 void matchNumSubtreesD(
77 int k,
78 const char* filename,
79 int lineno,
80 size_t expected_subtrees,
81 bool allow_more) const {
82 if (kind() != k) {
83 std::stringstream ss;
84 ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
85 << "' but found '" << kindToString(kind()) << "'\n";
86 range().highlight(ss);
87 throw std::runtime_error(ss.str());
88 }
89 if (trees().size() < expected_subtrees ||
90 (!allow_more && trees().size() != expected_subtrees)) {
91 std::stringstream ss;
92 ss << filename << ":" << lineno << ": expected at least "
93 << expected_subtrees << " subtrees, but found only " << trees().size()
94 << "\n";
95 range().highlight(ss);
96 throw std::runtime_error(ss.str());
97 }
98 }
99 ~Tree() override = default;
100
101 private:
102 int kind_;
103 };
104
105 struct String : public Tree {
StringString106 String(std::string value) : Tree(TK_STRING), value_(std::move(value)) {}
stringValueString107 const std::string& stringValue() const override {
108 return value_;
109 }
110 template <typename... Args>
createString111 static TreeRef create(Args&&... args) {
112 return c10::make_intrusive<String>(std::forward<Args>(args)...);
113 }
114
115 private:
116 std::string value_;
117 };
118
mergeRanges(SourceRange c,const TreeList & others)119 static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
120 for (const auto& t : others) {
121 if (t->isAtom())
122 continue;
123 size_t s = std::min(c.start(), t->range().start());
124 size_t e = std::max(c.end(), t->range().end());
125 c = SourceRange(c.source(), s, e);
126 }
127 return c;
128 }
129
130 struct Compound : public Tree {
CompoundCompound131 Compound(int kind, SourceRange range)
132 : Tree(kind), range_(std::move(range)) {}
CompoundCompound133 Compound(int kind, const SourceRange& range_, TreeList&& trees_)
134 : Tree(kind),
135 range_(mergeRanges(range_, trees_)),
136 trees_(std::move(trees_)) {}
treesCompound137 const TreeList& trees() const override {
138 return trees_;
139 }
createCompound140 static TreeRef create(
141 int kind,
142 const SourceRange& range_,
143 TreeList&& trees_) {
144 return c10::make_intrusive<Compound>(kind, range_, std::move(trees_));
145 }
isAtomCompound146 bool isAtom() const override {
147 return false;
148 }
mapCompound149 TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
150 TreeList ret;
151 for (auto& t : trees()) {
152 ret.push_back(fn(t));
153 }
154 return Compound::create(kind(), range(), std::move(ret));
155 }
156
rangeCompound157 const SourceRange& range() const override {
158 return range_;
159 }
160
161 private:
162 SourceRange range_;
163 TreeList trees_;
164 };
165
166 // tree pretty printer
167 struct pretty_tree {
treepretty_tree168 pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
169 const TreeRef& tree;
170 size_t col;
171 std::unordered_map<TreeRef, std::string> flat_strings;
get_flatpretty_tree172 const std::string& get_flat(const TreeRef& t) {
173 auto it = flat_strings.find(t);
174 if (it != flat_strings.end())
175 return it->second;
176
177 std::stringstream out;
178 switch (t->kind()) {
179 case TK_STRING:
180 out << t->stringValue();
181 break;
182 default:
183 out << "(" << kindToString(t->kind());
184 for (const auto& e : t->trees()) {
185 out << " " << get_flat(e);
186 }
187 out << ")";
188 break;
189 }
190 auto it_ = flat_strings.emplace(t, out.str());
191 return it_.first->second;
192 }
printpretty_tree193 void print(std::ostream& out, const TreeRef& t, int indent) {
194 const std::string& s = get_flat(t);
195 if (indent + s.size() < col || t->isAtom()) {
196 out << s;
197 return;
198 }
199 std::string k = kindToString(t->kind());
200 out << "(" << k;
201 for (const auto& e : t->trees()) {
202 out << "\n" << std::string(indent + 2, ' ');
203 print(out, e, indent + 2);
204 }
205 out << ")";
206 }
207 };
208
209 static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
210 t_.print(out, t_.tree, 0);
211 return out << '\n';
212 }
213
214 static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
215 return out << pretty_tree(t);
216 }
217
218 } // namespace torch::jit
219