1 #pragma once
2 #include <torch/csrc/jit/frontend/error_report.h>
3 #include <torch/csrc/jit/frontend/strtod.h>
4 #include <torch/csrc/jit/frontend/tree.h>
5
6 #include <c10/util/complex.h>
7 #include <functional>
8 #include <iostream>
9 #include <string>
10 #include <utility>
11
12 namespace torch::jit {
13
14 // clang-format off
15 // TreeView provides a statically-typed way to traverse the tree, which should
16 // be formed according to the grammar below.
17 //
18 // A few notes on types and their aliases:
19 // - List<T> is really a Tree with kind TK_LIST and elements as subtrees
20 // - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
21 // - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
22 //
// Param = Param(Ident name, Maybe<Expr> type, Maybe<Expr> default, TK_TRUE|TK_FALSE kwarg_only)   TK_PARAM
24 //
25 // Decl = Decl(List<Param> params, Maybe<Expr> return_type) TK_DECL
26 // Def = Def(Ident name, Decl decl, List<Stmt> body) TK_DEF
27 // ClassDef = ClassDef(Ident name, TK_CLASS_DEF
28 // Maybe<Expr> superclass,
29 // List<Stmt> body)
30 //
31 // Stmt = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body) TK_IF
32 // | For(List<Expr> targets, List<Expr> iters, List<Stmt> body) TK_FOR
33 // | While(Expr cond, List<Stmt> body) TK_WHILE
34 // | Global(List<Ident> idents) TK_GLOBAL
35 // -- NB: the only type of Expr's allowed on lhs are Var
36 // Or a tuple containing Var with an optional terminating Starred
//        | Assign(List<Expr> lhs, Maybe<Expr> rhs, Maybe<Expr> type)            TK_ASSIGN
38 // | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
//        | Return(Expr value)                                                  TK_RETURN
//        | ExprStmt(Expr expr)                                                 TK_EXPR_STMT
41 // | Raise(Expr expr) TK_RAISE
42 // | Def TK_DEF
43 // | With(List<WithItem> targets, List<Stmt> body) TK_WITH
44 //
45 // Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
46 // | BinOp(Expr lhs, Expr rhs)
47 // | And TK_AND
48 // | Or TK_OR
49 // | Lt '<'
50 // | Gt '>'
51 // | Eq TK_EQ
52 // | Le TK_LE
53 // | Ge TK_GE
54 // | Ne TK_NE
55 // | Is TK_IS
56 // | IsNot TK_ISNOT
57 // | Add '+'
58 // | Sub '-'
59 // | Mul '*'
60 // | Div '/'
61 // | Mod '%'
62 // | MatMult '@'
63 // | Pow TK_POW
64 // | UnaryOp(Expr expr)
65 // | Not TK_NOT
66 // | USub '-'
67 // | Const(String value) TK_CONST
68 // -- NB: x.name(y) is desugared into name(x, y)
//        | Apply(Expr callee, List<Expr> inputs, List<Attribute> attributes)   TK_APPLY
70 // | Select(Expr value, Ident selector) '.'
71 // | Subscript(Expr value, List<Expr> subscript_exprs) TK_SUBSCRIPT
//        | SliceExpr(Maybe<Expr> start, Maybe<Expr> end, Maybe<Expr> step)     TK_SLICE_EXPR
73 // | Var(Ident name) TK_VAR
74 // | ListLiteral(List<Expr> inputs) TK_LIST_LITERAL
75 // | TupleLiteral(List<Expr> inputs) TK_TUPLE_LITERAL
76 // | Starred(Expr expr) TK_STARRED
77 // | WithItem(Expr target, Maybe<Var> var) TK_WITH_ITEM
78 // -- NB: only allowed expressions are Const or List(Const)
79 // (List as a value, not type constructor)
80 // Attribute = Attribute(Ident name, Expr value) TK_ATTRIBUTE
81 //
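// AugAssignKind (the node's kind is the operator token itself) =
//                | Add()                                                TK_PLUS ('+')
//                | Sub()                                                '-'
//                | Mul()                                                '*'
//                | Div()                                                '/'
//                | Mod()                                                '%'
//                | BitOr()                                              '|'
//                | BitAnd()                                             '&'
//                | BitXor()                                             '^'
//                | Pow()                                                TK_POW
//                | LShift()                                             TK_LSHIFT
//                | RShift()                                             TK_RSHIFT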
88 //
89
90 // Each subclass of TreeView should provide:
91 // 1. Constructor that takes a TreeRef, and checks that it's of the right type.
92 // 2. Accessors that get underlying information out of the object. If they
93 // return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
//    For every TreeRef kind that has a TreeView, the parser always uses
//    (e.g.) Ident::create rather than Compound::create. This means that
//    changes to the structure of Ident are made only here, rather than in
//    both the parser and this code.
99 // XXX: these structs should have no fields to prevent slicing when passing by value
100 // clang-format on
101 struct TreeView {
TreeViewTreeView102 explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
treeTreeView103 TreeRef tree() const {
104 return tree_;
105 }
rangeTreeView106 const SourceRange& range() const {
107 return tree_->range();
108 }
TreeRefTreeView109 operator TreeRef() const {
110 return tree_;
111 }
getTreeView112 const TreeRef& get() const {
113 return tree_;
114 }
kindTreeView115 int kind() const {
116 return tree_->kind();
117 }
dumpTreeView118 void dump() const {
119 std::cout << tree_;
120 }
121
122 protected:
subtreeTreeView123 const TreeRef& subtree(size_t i) const {
124 return tree_->trees().at(i);
125 }
126 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
127 TreeRef tree_;
128 };
129
130 template <typename T>
131 struct ListIterator {
ListIteratorListIterator132 ListIterator(TreeList::const_iterator it) : it(it) {}
133 bool operator!=(const ListIterator& rhs) const {
134 return it != rhs.it;
135 }
136 bool operator==(const ListIterator& rhs) const {
137 return it == rhs.it;
138 }
139 T operator*() const {
140 return T(*it);
141 }
142 ListIterator& operator+=(std::ptrdiff_t n) {
143 it += n;
144 return *this;
145 }
146 ListIterator& operator++() {
147 ++it;
148 return *this;
149 }
150 ListIterator& operator--() {
151 --it;
152 return *this;
153 }
154
155 private:
156 TreeList::const_iterator it;
157 };
158
159 template <typename T>
160 struct List : public TreeView {
161 using iterator = ListIterator<T>;
162 using const_iterator = ListIterator<T>;
163
ListList164 List(const TreeRef& tree) : TreeView(tree) {
165 tree->match(TK_LIST);
166 // Iterate over list to temporarily instantiate Ts that will check the type
167 for (const T& elem : *this) {
168 (void)elem; // silence unused warning
169 }
170 }
beginList171 iterator begin() const {
172 return iterator(tree_->trees().begin());
173 }
endList174 iterator end() const {
175 return iterator(tree_->trees().end());
176 }
emptyList177 bool empty() const {
178 return tree_->trees().begin() == tree_->trees().end();
179 }
180 T operator[](size_t i) const {
181 return T(subtree(i));
182 }
mapList183 TreeRef map(const std::function<TreeRef(const T&)>& fn) {
184 return tree_->map([&](TreeRef v) { return fn(T(v)); });
185 }
createList186 static List create(const SourceRange& range, const std::vector<T>& subtrees) {
187 TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
188 return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
189 }
unsafeCreateList190 static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
191 return List(Compound::create(TK_LIST, range, std::move(subtrees)));
192 }
sizeList193 size_t size() const {
194 return tree_->trees().size();
195 }
196 };
197
198 template <typename T>
199 struct Maybe : public TreeView {
MaybeMaybe200 explicit Maybe(const TreeRef& tree) : TreeView(tree) {
201 tree_->match(TK_OPTION);
202 if (tree_->trees().size() > 1)
203 throw(ErrorReport(tree) << "Maybe trees can have at most one subtree");
204 }
MaybeMaybe205 /* implicit */ Maybe(const T& tree) : TreeView(tree) {}
presentMaybe206 bool present() const {
207 return tree_->trees().size() > 0;
208 }
getMaybe209 T get() const {
210 return T(tree_->trees().at(0));
211 }
mapMaybe212 TreeRef map(const std::function<TreeRef(const T&)>& fn) {
213 return tree_->map([&](TreeRef v) { return fn(T(v)); });
214 }
createMaybe215 static Maybe<T> create(const SourceRange& range) {
216 return Maybe<T>(Compound::create(TK_OPTION, range, {}));
217 }
createMaybe218 static Maybe<T> create(const SourceRange& range, const T& value) {
219 return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
220 }
221 };
222
223 struct Ident : public TreeView {
IdentIdent224 explicit Ident(const TreeRef& tree) : TreeView(tree) {
225 tree_->match(TK_IDENT);
226 }
nameIdent227 const std::string& name() const {
228 return subtree(0)->stringValue();
229 }
createIdent230 static Ident create(const SourceRange& range, std::string name) {
231 return Ident(
232 Compound::create(TK_IDENT, range, {String::create(std::move(name))}));
233 }
234 };
235
236 ////////////////////////////////////////////////////////////////////////////////
237 // Base types (production LHS)
238 ////////////////////////////////////////////////////////////////////////////////
239
240 struct Stmt : public TreeView {
StmtStmt241 explicit Stmt(const TreeRef& tree) : TreeView(tree) {
242 switch (tree->kind()) {
243 case TK_IF:
244 case TK_FOR:
245 case TK_WHILE:
246 case TK_GLOBAL:
247 case TK_ASSIGN:
248 case TK_AUG_ASSIGN:
249 case TK_RETURN:
250 case TK_EXPR_STMT:
251 case TK_RAISE:
252 case TK_ASSERT:
253 case TK_PASS:
254 case TK_BREAK:
255 case TK_DELETE:
256 case TK_CONTINUE:
257 case TK_DEF:
258 case TK_WITH:
259 return;
260 default:
261 throw(
262 ErrorReport(tree)
263 << kindToString(tree->kind()) << " is not a valid Stmt");
264 }
265 }
266 };
267
268 struct Expr : public TreeView {
ExprExpr269 explicit Expr(const TreeRef& tree) : TreeView(tree) {
270 switch (tree->kind()) {
271 case TK_IF_EXPR:
272 case TK_AND:
273 case TK_OR:
274 case '<':
275 case '>':
276 case TK_IS:
277 case TK_ISNOT:
278 case TK_EQ:
279 case TK_LE:
280 case TK_GE:
281 case TK_NE:
282 case '+':
283 case '-':
284 case TK_UNARY_MINUS:
285 case '~':
286 case '*':
287 case TK_STARRED:
288 case '/':
289 case '%':
290 case TK_NOT:
291 case TK_CONST:
292 case TK_STRINGLITERAL:
293 case TK_TRUE:
294 case TK_FALSE:
295 case TK_NONE:
296 case TK_NONE_TYPE:
297 case TK_CAST:
298 case TK_APPLY:
299 case '.':
300 case TK_SUBSCRIPT:
301 case TK_SLICE_EXPR:
302 case TK_VAR:
303 case TK_LIST_LITERAL:
304 case TK_TUPLE_LITERAL:
305 case TK_DICT_LITERAL:
306 case '@':
307 case TK_POW:
308 case TK_LSHIFT:
309 case TK_RSHIFT:
310 case TK_FLOOR_DIV:
311 case '&':
312 case '^':
313 case '|':
314 case TK_LIST_COMP:
315 case TK_DICT_COMP:
316 case TK_DOTS:
317 case TK_IN:
318 case TK_WITH_ITEM:
319 return;
320 default:
321 throw(
322 ErrorReport(tree)
323 << kindToString(tree->kind()) << " is not a valid Expr");
324 }
325 }
326 };
327
328 ////////////////////////////////////////////////////////////////////////////////
329 // Helper nodes (mostly for function arguments)
330 ////////////////////////////////////////////////////////////////////////////////
331
332 struct Attribute : public TreeView {
AttributeAttribute333 explicit Attribute(const TreeRef& tree) : TreeView(tree) {
334 tree_->match(TK_ATTRIBUTE);
335 }
nameAttribute336 Ident name() const {
337 return Ident(subtree(0));
338 }
valueAttribute339 Expr value() const {
340 return Expr(subtree(1));
341 }
createAttribute342 static Attribute create(
343 const SourceRange& range,
344 const Ident& name,
345 const TreeRef& value) {
346 return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
347 }
348 };
349
350 struct Param : public TreeView {
ParamParam351 explicit Param(const TreeRef& tree) : TreeView(tree) {
352 tree_->match(TK_PARAM);
353 }
createParam354 static Param create(
355 const SourceRange& range,
356 const Ident& ident,
357 const Maybe<Expr>& type,
358 const Maybe<Expr>& def,
359 bool kwarg_only) {
360 TreeRef kwarg_only_tree =
361 Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
362 return Param(Compound::create(
363 TK_PARAM, range, {ident, type, def, std::move(kwarg_only_tree)}));
364 }
identParam365 Ident ident() const {
366 return Ident(subtree(0));
367 }
typeParam368 Maybe<Expr> type() const {
369 return Maybe<Expr>(subtree(1));
370 }
defaultValueParam371 Maybe<Expr> defaultValue() const {
372 return Maybe<Expr>(subtree(2));
373 }
kwarg_onlyParam374 bool kwarg_only() const {
375 return TK_TRUE == subtree(3)->kind();
376 }
withTypeParam377 Param withType(const Maybe<Expr>& typ) const {
378 return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
379 }
380 };
381
382 ////////////////////////////////////////////////////////////////////////////////
383 // Top level definitions
384 ////////////////////////////////////////////////////////////////////////////////
385
386 struct Decl : public TreeView {
DeclDecl387 explicit Decl(const TreeRef& tree) : TreeView(tree) {
388 tree->match(TK_DECL);
389 }
paramsDecl390 List<Param> params() const {
391 return List<Param>(subtree(0));
392 }
return_typeDecl393 Maybe<Expr> return_type() const {
394 return Maybe<Expr>(subtree(1));
395 }
createDecl396 static Decl create(
397 const SourceRange& range,
398 const List<Param>& params,
399 const Maybe<Expr>& return_type) {
400 return Decl(Compound::create(TK_DECL, range, {params, return_type}));
401 }
402 };
403
404 struct Def : public TreeView {
DefDef405 explicit Def(const TreeRef& tree) : TreeView(tree) {
406 tree->match(TK_DEF);
407 }
withNameDef408 Def withName(std::string new_name) const {
409 auto new_ident = Ident::create(name().range(), std::move(new_name));
410 return create(range(), new_ident, decl(), statements());
411 }
withDeclDef412 Def withDecl(const Decl& decl) const {
413 return create(range(), name(), decl, statements());
414 }
nameDef415 Ident name() const {
416 return Ident(subtree(0));
417 }
declDef418 Decl decl() const {
419 return Decl(subtree(1));
420 }
statementsDef421 List<Stmt> statements() const {
422 return List<Stmt>(subtree(2));
423 }
createDef424 static Def create(
425 const SourceRange& range,
426 const Ident& name,
427 const Decl& decl,
428 const List<Stmt>& stmts) {
429 return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
430 }
431 };
432
433 // Property represents a named attribute combined with a getter and setter
434 // method to access and mutate that attribute.
435 struct Property : public TreeView {
PropertyProperty436 explicit Property(const TreeRef& tree) : TreeView(tree) {
437 tree->match(TK_PROP);
438 }
nameProperty439 Ident name() const {
440 return Ident(subtree(0));
441 }
getterProperty442 Def getter() const {
443 return Def(subtree(1));
444 }
setterProperty445 Maybe<Def> setter() const {
446 return Maybe<Def>(subtree(2));
447 }
createProperty448 static Property create(
449 const SourceRange& range,
450 const Ident& name,
451 const Def& getter,
452 const Maybe<Def>& setter) {
453 return Property(Compound::create(TK_PROP, range, {name, getter, setter}));
454 }
455 };
456
457 struct Assign;
458
459 struct ClassDef : public TreeView {
ClassDefClassDef460 explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
461 tree->match(TK_CLASS_DEF);
462 }
ClassDefClassDef463 explicit ClassDef(TreeRef&& tree) : TreeView(std::move(tree)) {
464 tree_->match(TK_CLASS_DEF);
465 }
withNameClassDef466 ClassDef withName(std::string new_name) const {
467 auto new_ident = Ident::create(name().range(), std::move(new_name));
468 return create(range(), new_ident, superclass(), body());
469 }
nameClassDef470 Ident name() const {
471 return Ident(subtree(0));
472 }
superclassClassDef473 Maybe<Expr> superclass() const {
474 return Maybe<Expr>(subtree(1));
475 }
bodyClassDef476 List<Stmt> body() const {
477 return List<Stmt>(subtree(2));
478 }
propertiesClassDef479 Maybe<List<Property>> properties() const {
480 return Maybe<List<Property>>(subtree(3));
481 }
assignsClassDef482 Maybe<List<Assign>> assigns() const {
483 return Maybe<List<Assign>>(subtree(4));
484 }
createClassDef485 static ClassDef create(
486 const SourceRange& range,
487 const Ident& name,
488 const Maybe<Expr>& superclass,
489 const List<Stmt>& body) {
490 return ClassDef(Compound::create(
491 TK_CLASS_DEF,
492 range,
493 {name,
494 superclass,
495 body,
496 Maybe<List<Property>>::create(range),
497 Maybe<List<Assign>>::create(range)}));
498 }
499 static ClassDef create(
500 const SourceRange& range,
501 const Ident& name,
502 const Maybe<Expr>& superclass,
503 const List<Stmt>& body,
504 const List<Property>& properties,
505 const List<Assign>& assigns);
506 };
507
508 TORCH_API std::vector<std::string> getUnresolvedClassAttributes(
509 const ClassDef& def);
510
511 ////////////////////////////////////////////////////////////////////////////////
512 // Statements
513 ////////////////////////////////////////////////////////////////////////////////
514
515 struct If : public Stmt {
IfIf516 explicit If(const TreeRef& tree) : Stmt(tree) {
517 tree_->match(TK_IF);
518 }
condIf519 Expr cond() const {
520 return Expr(subtree(0));
521 }
trueBranchIf522 List<Stmt> trueBranch() const {
523 return List<Stmt>(subtree(1));
524 }
falseBranchIf525 List<Stmt> falseBranch() const {
526 return List<Stmt>(subtree(2));
527 }
withNewBranchesIf528 If withNewBranches(
529 const List<Stmt>& true_branch,
530 const List<Stmt>& false_branch) const {
531 return create(range(), cond(), true_branch, false_branch);
532 }
createIf533 static If create(
534 const SourceRange& range,
535 const Expr& cond,
536 const List<Stmt>& true_branch,
537 const List<Stmt>& false_branch) {
538 return If(
539 Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
540 }
541 };
542
543 struct While : public Stmt {
WhileWhile544 explicit While(const TreeRef& tree) : Stmt(tree) {
545 tree_->match(TK_WHILE);
546 }
condWhile547 Expr cond() const {
548 return Expr(subtree(0));
549 }
bodyWhile550 List<Stmt> body() const {
551 return List<Stmt>(subtree(1));
552 }
createWhile553 static While create(
554 const SourceRange& range,
555 const Expr& cond,
556 const List<Stmt>& body) {
557 return While(Compound::create(TK_WHILE, range, {cond, body}));
558 }
559 };
560
561 struct For : public Stmt {
ForFor562 explicit For(const TreeRef& tree) : Stmt(tree) {
563 tree->match(TK_FOR);
564 }
targetsFor565 List<Expr> targets() const {
566 return List<Expr>(subtree(0));
567 }
itrsFor568 List<Expr> itrs() const {
569 return List<Expr>(subtree(1));
570 }
bodyFor571 List<Stmt> body() const {
572 return List<Stmt>(subtree(2));
573 }
createFor574 static For create(
575 const SourceRange& range,
576 const List<Expr>& targets,
577 const List<Expr>& itrs,
578 const List<Stmt>& body) {
579 return For(Compound::create(TK_FOR, range, {targets, itrs, body}));
580 }
581 };
582
583 // TODO: supports only single comprehension for now
584 struct ListComp : public Expr {
ListCompListComp585 explicit ListComp(const TreeRef& tree) : Expr(tree) {
586 tree->match(TK_LIST_COMP);
587 }
eltListComp588 Expr elt() const {
589 return Expr(subtree(0));
590 }
targetListComp591 Expr target() const {
592 return Expr(subtree(1));
593 }
iterListComp594 Expr iter() const {
595 return Expr(subtree(2));
596 }
597 // TODO: no ifs for now
createListComp598 static ListComp create(
599 const SourceRange& range,
600 const Expr& elt,
601 const Expr& target,
602 const Expr& iter) {
603 return ListComp(Compound::create(TK_LIST_COMP, range, {elt, target, iter}));
604 }
605 };
606
607 // TODO: supports only single comprehension for now
608 struct DictComp : public Expr {
DictCompDictComp609 explicit DictComp(const TreeRef& tree) : Expr(tree) {
610 tree->match(TK_DICT_COMP);
611 }
keyDictComp612 Expr key() const {
613 return Expr(subtree(0));
614 }
valueDictComp615 Expr value() const {
616 return Expr(subtree(1));
617 }
targetDictComp618 Expr target() const {
619 return Expr(subtree(2));
620 }
iterDictComp621 Expr iter() const {
622 return Expr(subtree(3));
623 }
624 // TODO: no ifs for now
createDictComp625 static DictComp create(
626 const SourceRange& range,
627 const Expr& key,
628 const Expr& value,
629 const Expr& target,
630 const Expr& iter) {
631 return DictComp(
632 Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
633 }
634 };
635
636 struct Global : public Stmt {
GlobalGlobal637 explicit Global(const TreeRef& tree) : Stmt(tree) {
638 tree_->match(TK_GLOBAL);
639 }
namesGlobal640 List<Ident> names() {
641 return List<Ident>(subtree(0));
642 }
createGlobal643 static Global create(const SourceRange& range, const List<Ident>& names) {
644 return Global(Compound::create(TK_GLOBAL, range, {names}));
645 }
646 };
647
648 struct AugAssignKind : public TreeView {
AugAssignKindAugAssignKind649 explicit AugAssignKind(const TreeRef& tree) : TreeView(tree) {
650 switch (tree->kind()) {
651 case '+':
652 case '-':
653 case '*':
654 case '/':
655 case '%':
656 case '|':
657 case '&':
658 case '^':
659 case TK_POW:
660 case TK_LSHIFT:
661 case TK_RSHIFT:
662 return;
663 default:
664 throw(ErrorReport(tree) << "is not a valid AugAssignKind");
665 }
666 }
667 };
668
669 // Augmented assignment, like "foo += bar"
670 struct AugAssign : public Stmt {
AugAssignAugAssign671 explicit AugAssign(const TreeRef& tree) : Stmt(tree) {
672 tree_->match(TK_AUG_ASSIGN);
673 }
createAugAssign674 static AugAssign create(
675 const SourceRange& range,
676 const Expr& lhs,
677 const AugAssignKind& aug_op,
678 const Expr& rhs) {
679 return AugAssign(
680 Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs}));
681 }
lhsAugAssign682 Expr lhs() const {
683 return Expr(subtree(0));
684 }
aug_opAugAssign685 int aug_op() const {
686 return subtree(1)->kind();
687 }
rhsAugAssign688 Expr rhs() const {
689 return Expr(subtree(2));
690 }
691 };
692
693 struct Assign : public Stmt {
AssignAssign694 explicit Assign(const TreeRef& tree) : Stmt(tree) {
695 tree_->match(TK_ASSIGN);
696 }
createAssign697 static Assign create(
698 const SourceRange& range,
699 const List<Expr>& lhs,
700 const Maybe<Expr>& rhs,
701 const Maybe<Expr>& type) {
702 return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs, type}));
703 }
704
lhs_listAssign705 List<Expr> lhs_list() const {
706 return List<Expr>(subtree(0));
707 }
708
lhsAssign709 Expr lhs() const {
710 const auto& li = lhs_list();
711 TORCH_INTERNAL_ASSERT(li.size() == 1);
712 return *li.begin();
713 }
714
rhsAssign715 Maybe<Expr> rhs() const {
716 return Maybe<Expr>(subtree(1));
717 }
718
typeAssign719 Maybe<Expr> type() const {
720 return Maybe<Expr>(subtree(2));
721 }
722 };
723
724 struct Return : public Stmt {
ReturnReturn725 explicit Return(const TreeRef& tree) : Stmt(tree) {
726 tree_->match(TK_RETURN);
727 }
exprReturn728 Expr expr() const {
729 return Expr(subtree(0));
730 }
createReturn731 static Return create(const SourceRange& range, const Expr& value) {
732 return Return(Compound::create(TK_RETURN, range, {value}));
733 }
734 };
735
736 struct Raise : public Stmt {
RaiseRaise737 explicit Raise(const TreeRef& tree) : Stmt(tree) {
738 tree_->match(TK_RAISE);
739 }
exprRaise740 Expr expr() const {
741 return Expr(subtree(0));
742 }
createRaise743 static Raise create(const SourceRange& range, const Expr& expr) {
744 return Raise(Compound::create(TK_RAISE, range, {expr}));
745 }
746 };
747
748 struct Assert : public Stmt {
AssertAssert749 explicit Assert(const TreeRef& tree) : Stmt(tree) {
750 tree_->match(TK_ASSERT);
751 }
testAssert752 Expr test() const {
753 return Expr(subtree(0));
754 }
msgAssert755 Maybe<Expr> msg() const {
756 return Maybe<Expr>(subtree(1));
757 }
createAssert758 static Assert create(
759 const SourceRange& range,
760 const Expr& test,
761 const Maybe<Expr>& msg) {
762 return Assert(Compound::create(TK_ASSERT, range, {test, msg}));
763 }
764 };
765
766 struct Pass : public Stmt {
PassPass767 explicit Pass(const TreeRef& tree) : Stmt(tree) {
768 tree_->match(TK_PASS);
769 }
createPass770 static Pass create(const SourceRange& range) {
771 return Pass(Compound::create(TK_PASS, range, {}));
772 }
773 };
774
775 struct Dots : public Expr {
DotsDots776 explicit Dots(const TreeRef& tree) : Expr(tree) {
777 tree_->match(TK_DOTS);
778 }
createDots779 static Dots create(const SourceRange& range) {
780 return Dots(Compound::create(TK_DOTS, range, {}));
781 }
782 };
783
784 struct Break : public Stmt {
BreakBreak785 explicit Break(const TreeRef& tree) : Stmt(tree) {
786 tree_->match(TK_BREAK);
787 }
createBreak788 static Break create(const SourceRange& range) {
789 return Break(Compound::create(TK_BREAK, range, {}));
790 }
791 };
792
793 struct Continue : public Stmt {
ContinueContinue794 explicit Continue(const TreeRef& tree) : Stmt(tree) {
795 tree_->match(TK_CONTINUE);
796 }
createContinue797 static Continue create(const SourceRange& range) {
798 return Continue(Compound::create(TK_CONTINUE, range, {}));
799 }
800 };
801
802 struct ExprStmt : public Stmt {
ExprStmtExprStmt803 explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
804 tree_->match(TK_EXPR_STMT);
805 }
exprExprStmt806 Expr expr() {
807 return Expr(subtree(0));
808 }
createExprStmt809 static ExprStmt create(const SourceRange& range, const Expr& list) {
810 return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list}));
811 }
812 };
813
814 ////////////////////////////////////////////////////////////////////////////////
815 // Expressions
816 ////////////////////////////////////////////////////////////////////////////////
817
818 struct BinOp : public Expr {
BinOpBinOp819 explicit BinOp(const TreeRef& tree) : Expr(tree) {
820 switch (tree->kind()) {
821 case TK_AND:
822 case TK_OR:
823 case '<':
824 case '>':
825 case TK_IS:
826 case TK_ISNOT:
827 case TK_EQ:
828 case TK_LE:
829 case TK_GE:
830 case TK_NE:
831 case '+':
832 case '*':
833 case '/':
834 case '-':
835 case '@':
836 case TK_POW:
837 case TK_LSHIFT:
838 case TK_RSHIFT:
839 case '%':
840 case '&':
841 case '^':
842 case '|':
843 case TK_FLOOR_DIV:
844 case TK_IN:
845 if (tree->trees().size() != 2)
846 throw(
847 ErrorReport(tree)
848 << "BinOp expected 2 subtrees, found " << tree->trees().size());
849 return;
850 default:
851 throw(
852 ErrorReport(tree)
853 << kindToString(tree->kind()) << " is not a valid BinOp");
854 }
855 }
lhsBinOp856 Expr lhs() const {
857 return Expr(subtree(0));
858 }
rhsBinOp859 Expr rhs() const {
860 return Expr(subtree(1));
861 }
createBinOp862 static BinOp create(
863 const SourceRange& range,
864 int kind,
865 const Expr& lhs,
866 const Expr& rhs) {
867 return BinOp(Compound::create(kind, range, {lhs, rhs}));
868 }
869 };
870
871 struct UnaryOp : public Expr {
UnaryOpUnaryOp872 explicit UnaryOp(const TreeRef& tree) : Expr(tree) {
873 switch (tree->kind()) {
874 case TK_UNARY_MINUS:
875 case '~':
876 case TK_NOT:
877 if (tree->trees().size() != 1)
878 throw(
879 ErrorReport(tree)
880 << "UnaryOp expected 1 subtree, found " << tree->trees().size());
881 return;
882 default:
883 throw(
884 ErrorReport(tree)
885 << kindToString(tree->kind()) << " is not a valid UnaryOp");
886 }
887 }
createUnaryOp888 static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
889 return UnaryOp(Compound::create(kind, range, {expr}));
890 }
891 };
892
893 struct Const : public Expr {
ConstConst894 explicit Const(const TreeRef& tree) : Expr(tree) {
895 tree_->matchNumSubtrees(TK_CONST, 1);
896 }
isFloatingPointConst897 bool isFloatingPoint() const {
898 if (isComplex())
899 return false;
900
901 bool is_inf = subtree(0)->stringValue() == "inf";
902 return is_inf ||
903 subtree(0)->stringValue().find_first_of(".eE") != std::string::npos;
904 }
isIntegralConst905 bool isIntegral() const {
906 return !isFloatingPoint() && !isComplex();
907 }
isComplexConst908 bool isComplex() const {
909 return subtree(0)->stringValue().find_first_of('j') != std::string::npos;
910 }
asIntegralConst911 int64_t asIntegral() const {
912 try {
913 return std::stoll(subtree(0)->stringValue(), nullptr, 0);
914 } catch (const std::out_of_range&) {
915 throw(
916 ErrorReport(range()) << "Integral constant out of range "
917 "(must fit in a signed 64 bit integer)");
918 }
919 }
asFloatingPointConst920 double asFloatingPoint() const {
921 // We can't pass in nullptr as the dummy pointer gets dereferenced for
922 // Android version of strtod_c().
923 char* dummy = nullptr;
924 return torch::jit::strtod_c(subtree(0)->stringValue().c_str(), &dummy);
925 }
asComplexConst926 c10::complex<double> asComplex() const {
927 char* dummy = nullptr;
928 auto str = subtree(0)->stringValue();
929 // Complex numbers (a+bj, where a is non-zero) are parsed as an addition
930 // between float/int a and a complex number "bj". When a is 0, a complex
931 // number bj is created as above. So, while parsing the string, we don't
932 // have to worry about the real component of the complex number.
933 auto imag =
934 torch::jit::strtod_c(str.substr(0, str.size() - 1).c_str(), &dummy);
935 return c10::complex<double>(0, imag);
936 }
textConst937 const std::string& text() const {
938 return subtree(0)->stringValue();
939 }
createConst940 static Const create(const SourceRange& range, const std::string& value) {
941 return Const(Compound::create(TK_CONST, range, {String::create(value)}));
942 }
943 };
944
945 struct StringLiteral : public Expr {
StringLiteralStringLiteral946 explicit StringLiteral(const TreeRef& tree) : Expr(tree) {
947 tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
948 }
textStringLiteral949 const std::string& text() const {
950 return subtree(0)->stringValue();
951 }
createStringLiteral952 static StringLiteral create(
953 const SourceRange& range,
954 const std::string& value) {
955 return StringLiteral(
956 Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
957 }
958 };
959
960 struct Apply : public Expr {
ApplyApply961 explicit Apply(const TreeRef& tree) : Expr(tree) {
962 tree_->match(TK_APPLY);
963 }
calleeApply964 Expr callee() const {
965 return Expr(subtree(0));
966 }
inputsApply967 List<Expr> inputs() const {
968 return List<Expr>(subtree(1));
969 }
attributesApply970 List<Attribute> attributes() const {
971 return List<Attribute>(subtree(2));
972 }
createApply973 static Apply create(
974 const SourceRange& range,
975 const Expr& callee,
976 const List<Expr>& inputs,
977 const List<Attribute>& attributes) {
978 return Apply(
979 Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
980 }
981 };
982
983 struct Select : public Expr {
SelectSelect984 explicit Select(const TreeRef& tree) : Expr(tree) {
985 tree_->match('.');
986 }
valueSelect987 Expr value() const {
988 return Expr(subtree(0));
989 }
selectorSelect990 Ident selector() const {
991 return Ident(subtree(1));
992 }
createSelect993 static Select create(
994 const SourceRange& range,
995 const Expr& value,
996 const Ident& selector) {
997 return Select(Compound::create('.', range, {value, selector}));
998 }
999 };
1000
1001 struct SliceExpr : public Expr {
SliceExprSliceExpr1002 explicit SliceExpr(const TreeRef& tree) : Expr(tree) {
1003 tree_->match(TK_SLICE_EXPR);
1004 }
startSliceExpr1005 Maybe<Expr> start() const {
1006 return Maybe<Expr>(subtree(0));
1007 }
endSliceExpr1008 Maybe<Expr> end() const {
1009 return Maybe<Expr>(subtree(1));
1010 }
stepSliceExpr1011 Maybe<Expr> step() const {
1012 return Maybe<Expr>(subtree(2));
1013 }
startOrSliceExpr1014 Expr startOr(int64_t alternative) const {
1015 const auto startOption = start();
1016 return startOption.present() ? startOption.get() : createInt(alternative);
1017 }
endOrSliceExpr1018 Expr endOr(int64_t alternative) const {
1019 const auto endOption = end();
1020 return endOption.present() ? endOption.get() : createInt(alternative);
1021 }
stepOrSliceExpr1022 Expr stepOr(int64_t alternative) const {
1023 const auto stepOption = step();
1024 return stepOption.present() ? stepOption.get() : createInt(alternative);
1025 }
createSliceExpr1026 static SliceExpr create(
1027 const SourceRange& range,
1028 const Maybe<Expr>& start,
1029 const Maybe<Expr>& end,
1030 const Maybe<Expr>& step) {
1031 return SliceExpr(
1032 Compound::create(TK_SLICE_EXPR, range, {start, end, step}));
1033 }
1034
1035 private:
createIntSliceExpr1036 Expr createInt(int64_t value) const {
1037 return Expr(Const::create(range(), std::to_string(value)));
1038 }
1039 };
1040
1041 struct Subscript : public Expr {
SubscriptSubscript1042 explicit Subscript(const TreeRef& tree) : Expr(tree) {
1043 tree_->match(TK_SUBSCRIPT);
1044 }
valueSubscript1045 Expr value() const {
1046 return Expr(subtree(0));
1047 }
subscript_exprsSubscript1048 List<Expr> subscript_exprs() const {
1049 return List<Expr>(subtree(1));
1050 }
createSubscript1051 static Subscript create(
1052 const SourceRange& range,
1053 const Expr& value,
1054 const List<Expr>& subscript_exprs) {
1055 auto whole_range = SourceRange(
1056 range.source(), range.start(), subscript_exprs.range().end() + 1);
1057 return Subscript(
1058 Compound::create(TK_SUBSCRIPT, whole_range, {value, subscript_exprs}));
1059 }
1060 };
1061
1062 struct Var : public Expr {
VarVar1063 explicit Var(const TreeRef& tree) : Expr(tree) {
1064 tree_->match(TK_VAR);
1065 };
nameVar1066 Ident name() const {
1067 return Ident(subtree(0));
1068 }
createVar1069 static Var create(const SourceRange& range, const Ident& name) {
1070 return Var(Compound::create(TK_VAR, range, {name}));
1071 }
1072 };
1073
1074 // WithItem represents an item using with a WithStmt.
1075 struct WithItem : public Expr {
WithItemWithItem1076 explicit WithItem(const TreeRef& tree) : Expr(tree) {
1077 tree_->match(TK_WITH_ITEM);
1078 }
1079
targetWithItem1080 Expr target() const {
1081 return Expr(subtree(0));
1082 }
1083
varWithItem1084 Maybe<Var> var() const {
1085 return Maybe<Var>(subtree(1));
1086 }
1087
createWithItem1088 static WithItem create(
1089 const SourceRange& range,
1090 const Expr& target,
1091 const Maybe<Var>& var) {
1092 return WithItem(Compound::create(TK_WITH_ITEM, range, {target, var}));
1093 }
1094 };
1095
1096 // With represents a with statement consisting of a list of with items and a
1097 // body of statements.
1098 struct With : public Stmt {
WithWith1099 explicit With(const TreeRef& tree) : Stmt(tree) {
1100 tree_->match(TK_WITH);
1101 }
1102
targetsWith1103 List<WithItem> targets() const {
1104 return List<WithItem>(subtree(0));
1105 }
1106
bodyWith1107 List<Stmt> body() const {
1108 return List<Stmt>(subtree(1));
1109 }
1110
createWith1111 static With create(
1112 const SourceRange& range,
1113 const List<WithItem>& targets,
1114 const List<Stmt>& body) {
1115 return With(Compound::create(TK_WITH, range, {targets, body}));
1116 }
1117 };
1118
1119 struct TernaryIf : public Expr {
TernaryIfTernaryIf1120 explicit TernaryIf(const TreeRef& tree) : Expr(tree) {
1121 tree_->matchNumSubtrees(TK_IF_EXPR, 3);
1122 };
condTernaryIf1123 Expr cond() const {
1124 return Expr(subtree(0));
1125 }
true_exprTernaryIf1126 Expr true_expr() const {
1127 return Expr(subtree(1));
1128 }
false_exprTernaryIf1129 Expr false_expr() const {
1130 return Expr(subtree(2));
1131 }
createTernaryIf1132 static TernaryIf create(
1133 const SourceRange& range,
1134 const Expr& cond,
1135 const Expr& true_expr,
1136 const Expr& false_expr) {
1137 return TernaryIf(
1138 Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
1139 };
1140 };
1141
1142 struct ListLiteral : public Expr {
ListLiteralListLiteral1143 explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
1144 tree_->match(TK_LIST_LITERAL);
1145 }
inputsListLiteral1146 List<Expr> inputs() const {
1147 return subtree(0);
1148 }
createListLiteral1149 static ListLiteral create(
1150 const SourceRange& range,
1151 const List<Expr>& inputs) {
1152 return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
1153 }
1154 };
1155
1156 struct TupleLiteral : public Expr {
TupleLiteralTupleLiteral1157 explicit TupleLiteral(const TreeRef& tree) : Expr(tree) {
1158 tree_->match(TK_TUPLE_LITERAL);
1159 }
inputsTupleLiteral1160 List<Expr> inputs() const {
1161 return subtree(0);
1162 }
createTupleLiteral1163 static TupleLiteral create(
1164 const SourceRange& range,
1165 const List<Expr>& inputs) {
1166 return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
1167 }
1168 };
1169
1170 struct DictLiteral : public Expr {
DictLiteralDictLiteral1171 explicit DictLiteral(const TreeRef& tree) : Expr(tree) {
1172 tree_->match(TK_DICT_LITERAL);
1173 }
key_inputsDictLiteral1174 List<Expr> key_inputs() const {
1175 return subtree(0);
1176 }
value_inputsDictLiteral1177 List<Expr> value_inputs() const {
1178 return subtree(1);
1179 }
createDictLiteral1180 static DictLiteral create(
1181 const SourceRange& range,
1182 const List<Expr>& keys,
1183 const List<Expr>& values) {
1184 return DictLiteral(
1185 Compound::create(TK_DICT_LITERAL, range, {keys, values}));
1186 }
1187 };
1188
1189 struct Starred : public Expr {
StarredStarred1190 explicit Starred(const TreeRef& tree) : Expr(tree) {
1191 tree_->match(TK_STARRED);
1192 }
exprStarred1193 Expr expr() const {
1194 return Expr(subtree(0));
1195 }
createStarred1196 static Starred create(const SourceRange& range, const Expr& expr) {
1197 return Starred(Compound::create(TK_STARRED, range, {expr}));
1198 }
1199 };
1200
1201 struct Delete : public Stmt {
DeleteDelete1202 explicit Delete(const TreeRef& tree) : Stmt(tree) {
1203 tree_->match(TK_DELETE);
1204 }
targetsDelete1205 List<Expr> targets() const {
1206 return subtree(0);
1207 }
createDelete1208 static Delete create(const SourceRange& range, const List<Expr>& targets) {
1209 return Delete(Compound::create(TK_DELETE, range, {targets}));
1210 }
1211 };
1212
1213 /*
1214 * NOTE: transforming PEP 604 union into equivalent union type
1215 *
1216 * NOTE: Union[int, float] parses into:
1217 * <EXPR> expr:(subscript
1218 * (variable (ident Union))
1219 * (list
1220 * (variable (ident int))
1221 * (variable (ident float))))
1222 * <KIND> subscript
1223 *
1224 * NOTE: (int | float) parses into:
1225 * <EXPR> expr:(|
1226 * (variable (ident int))
1227 * (variable (ident float)))
1228 * <KIND> |
1229 */
1230
_flatten_pep604_union(const torch::jit::Expr & node,std::vector<torch::jit::Expr> * result)1231 inline void _flatten_pep604_union(
1232 const torch::jit::Expr& node,
1233 std::vector<torch::jit::Expr>* result) {
1234 // flatten possibly nested union expressions like (int | (float | str))
1235 // into a flat list of expressions like [int, float, str]
1236 if (node.kind() == '|') {
1237 auto as_binop = torch::jit::BinOp(node);
1238 _flatten_pep604_union(as_binop.lhs(), result);
1239 _flatten_pep604_union(as_binop.rhs(), result);
1240 } else {
1241 result->push_back(node);
1242 }
1243 }
1244
get_pep604_union_members(const Expr & node)1245 inline std::vector<Expr> get_pep604_union_members(const Expr& node) {
1246 std::vector<Expr> result;
1247 _flatten_pep604_union(node, &result);
1248 return result;
1249 }
1250
1251 // Flattens a PEP 604 union into a classical union.
1252 // For example, ((x | y) | z) is transformed into Union[x, y, z].
pep604union_to_union(const Expr & expr)1253 inline Expr pep604union_to_union(const Expr& expr) {
1254 // noop if not a pep604 union
1255 if (expr.kind() != '|')
1256 return expr;
1257
1258 // In order to support unions with more than 2 operands ((x|y)|z), we need to
1259 // recursively flatten the tree of | expressions.
1260 auto members = get_pep604_union_members(expr);
1261 auto synthesised_union = Subscript::create(
1262 expr.range(),
1263 Var::create(expr.range(), Ident::create(expr.range(), "Union")),
1264 List<Expr>::create(expr.range(), members));
1265 return std::move(synthesised_union);
1266 }
1267
1268 } // namespace torch::jit
1269
1270 namespace std {
1271
1272 template <typename T>
1273 struct iterator_traits<torch::jit::ListIterator<T>>
1274 : std::iterator_traits<torch::jit::TreeList::const_iterator> {};
1275
1276 } // namespace std
1277