xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/tree.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <functional>
4 #include <memory>
5 #include <unordered_map>
6 #include <vector>
7 
8 #include <c10/util/SmallVector.h>
9 #include <c10/util/intrusive_ptr.h>
10 #include <torch/csrc/jit/frontend/lexer.h>
11 
12 namespace torch::jit {
13 
14 // Trees are used to represent all forms of TC IR, pre- and post-typechecking.
15 // Rather than have a full class hierarchy for all TC statements, trees are a
16 // slight variation of Lisp s-expressions. For instance, the expression a*b+1
17 // is represented as:
18 // (+ (* (ident a) (ident b)) (const 1))
19 // Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
20 // define stringValue(). Everything else is a Compound object, which has a
21 // 'kind' that is a token from lexer.h's TokenKind enum. Single-character
22 // operators like '+' are represented using the character itself (so, add.kind()
23 // would be '+'). Each Compound object also contains a list of subtrees and is
24 // associated with a SourceRange for error reporting.
25 // Memory management of trees is done using intrusive_ptr.
26 
27 struct Tree;
28 using TreeRef = c10::intrusive_ptr<Tree>;
29 using TreeList = at::SmallVector<TreeRef, 4>;
30 
31 struct Tree : c10::intrusive_ptr_target {
TreeTree32   Tree(int kind_) : kind_(kind_) {}
kindTree33   int kind() const {
34     return kind_;
35   }
isAtomTree36   virtual bool isAtom() const {
37     return true;
38   }
rangeTree39   virtual const SourceRange& range() const {
40     throw std::runtime_error("is an Atom");
41   }
stringValueTree42   virtual const std::string& stringValue() const {
43     throw std::runtime_error("stringValue can only be called on TK_STRING");
44   }
treesTree45   virtual const TreeList& trees() const {
46     static const TreeList empty_trees = {};
47     return empty_trees;
48   }
treeTree49   const TreeRef& tree(size_t i) const {
50     return trees().at(i);
51   }
mapTree52   virtual TreeRef map(const std::function<TreeRef(TreeRef)>& fn) {
53     (void)fn;
54     c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
55                                            // from a raw `this` pointer
56                                            // so we need to bump the refcount
57                                            // to account for this ownership
58     return TreeRef::reclaim(this);
59   }
60   template <typename... Args>
matchTree61   void match(int k, Args&... args) const {
62     matchD(k, "unknown", 0, args...);
63   }
64   template <typename... Args>
matchDTree65   void matchD(int k, const char* filename, int lineno, Args&... args) const {
66     std::initializer_list<TreeRef*> vars = {args...};
67     matchNumSubtreesD(k, filename, lineno, vars.size(), true);
68     size_t i = 0;
69     for (TreeRef* v : vars) {
70       *v = trees()[i++];
71     }
72   }
matchNumSubtreesTree73   void matchNumSubtrees(int k, size_t expected_subtrees) {
74     return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false);
75   }
matchNumSubtreesDTree76   void matchNumSubtreesD(
77       int k,
78       const char* filename,
79       int lineno,
80       size_t expected_subtrees,
81       bool allow_more) const {
82     if (kind() != k) {
83       std::stringstream ss;
84       ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
85          << "' but found '" << kindToString(kind()) << "'\n";
86       range().highlight(ss);
87       throw std::runtime_error(ss.str());
88     }
89     if (trees().size() < expected_subtrees ||
90         (!allow_more && trees().size() != expected_subtrees)) {
91       std::stringstream ss;
92       ss << filename << ":" << lineno << ": expected at least "
93          << expected_subtrees << " subtrees, but found only " << trees().size()
94          << "\n";
95       range().highlight(ss);
96       throw std::runtime_error(ss.str());
97     }
98   }
99   ~Tree() override = default;
100 
101  private:
102   int kind_;
103 };
104 
105 struct String : public Tree {
StringString106   String(std::string value) : Tree(TK_STRING), value_(std::move(value)) {}
stringValueString107   const std::string& stringValue() const override {
108     return value_;
109   }
110   template <typename... Args>
createString111   static TreeRef create(Args&&... args) {
112     return c10::make_intrusive<String>(std::forward<Args>(args)...);
113   }
114 
115  private:
116   std::string value_;
117 };
118 
mergeRanges(SourceRange c,const TreeList & others)119 static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
120   for (const auto& t : others) {
121     if (t->isAtom())
122       continue;
123     size_t s = std::min(c.start(), t->range().start());
124     size_t e = std::max(c.end(), t->range().end());
125     c = SourceRange(c.source(), s, e);
126   }
127   return c;
128 }
129 
130 struct Compound : public Tree {
CompoundCompound131   Compound(int kind, SourceRange range)
132       : Tree(kind), range_(std::move(range)) {}
CompoundCompound133   Compound(int kind, const SourceRange& range_, TreeList&& trees_)
134       : Tree(kind),
135         range_(mergeRanges(range_, trees_)),
136         trees_(std::move(trees_)) {}
treesCompound137   const TreeList& trees() const override {
138     return trees_;
139   }
createCompound140   static TreeRef create(
141       int kind,
142       const SourceRange& range_,
143       TreeList&& trees_) {
144     return c10::make_intrusive<Compound>(kind, range_, std::move(trees_));
145   }
isAtomCompound146   bool isAtom() const override {
147     return false;
148   }
mapCompound149   TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
150     TreeList ret;
151     for (auto& t : trees()) {
152       ret.push_back(fn(t));
153     }
154     return Compound::create(kind(), range(), std::move(ret));
155   }
156 
rangeCompound157   const SourceRange& range() const override {
158     return range_;
159   }
160 
161  private:
162   SourceRange range_;
163   TreeList trees_;
164 };
165 
166 // tree pretty printer
167 struct pretty_tree {
treepretty_tree168   pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
169   const TreeRef& tree;
170   size_t col;
171   std::unordered_map<TreeRef, std::string> flat_strings;
get_flatpretty_tree172   const std::string& get_flat(const TreeRef& t) {
173     auto it = flat_strings.find(t);
174     if (it != flat_strings.end())
175       return it->second;
176 
177     std::stringstream out;
178     switch (t->kind()) {
179       case TK_STRING:
180         out << t->stringValue();
181         break;
182       default:
183         out << "(" << kindToString(t->kind());
184         for (const auto& e : t->trees()) {
185           out << " " << get_flat(e);
186         }
187         out << ")";
188         break;
189     }
190     auto it_ = flat_strings.emplace(t, out.str());
191     return it_.first->second;
192   }
printpretty_tree193   void print(std::ostream& out, const TreeRef& t, int indent) {
194     const std::string& s = get_flat(t);
195     if (indent + s.size() < col || t->isAtom()) {
196       out << s;
197       return;
198     }
199     std::string k = kindToString(t->kind());
200     out << "(" << k;
201     for (const auto& e : t->trees()) {
202       out << "\n" << std::string(indent + 2, ' ');
203       print(out, e, indent + 2);
204     }
205     out << ")";
206   }
207 };
208 
209 static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
210   t_.print(out, t_.tree, 0);
211   return out << '\n';
212 }
213 
214 static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
215   return out << pretty_tree(t);
216 }
217 
218 } // namespace torch::jit
219