xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/lexer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Macros.h>
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/csrc/jit/frontend/parser_constants.h>
6 #include <torch/csrc/jit/frontend/source_range.h>
7 #include <torch/csrc/jit/frontend/strtod.h>
8 #include <algorithm>
9 #include <clocale>
10 #include <cstdlib>
11 #include <memory>
12 #include <sstream>
13 #include <string>
14 #include <vector>
15 
16 namespace torch::jit {
17 
// X-macro listing every multi-character token kind. Each entry has the form
//   _(ENUM_NAME, "printable name", "lexer spelling")
// Single-character tokens (e.g. '+') have no entry: their kind is the
// character itself.
//
// If the lexer spelling (third field) is non-empty, SharedParserData inserts
// it into its token trie, so the lexer matches exactly that text as this
// kind. Kinds with an empty spelling are not in the trie. The lexer produces
// some of them by other means (TK_IDENT, TK_NUMBER, TK_STRINGLITERAL,
// TK_WHITESPACE, TK_INDENT, TK_DEDENT, ...).
//
// These kinds are also used in Tree.h as the kind of AST nodes. Some, such as
// TK_APPLY and TK_LIST, occur only in the AST and are never seen in the lexer.
//
// The enum values are assigned in list order (enum TokenKind expands this
// macro). Inserting or reordering entries therefore renumbers every later
// kind.
//
// NOTE(review): TK_IF and TK_IF_EXPR share the printable name "if". A lookup
// by name (e.g. stringToKind) presumably cannot tell them apart — confirm.
#define TC_FORALL_TOKEN_KINDS(_)                 \
  _(TK_EOF, "eof", "")                           \
  _(TK_WHITESPACE, "whitespace", "")             \
  _(TK_WHITESPACE_EOF, "whitespace_eof", "")     \
  _(TK_NUMBER, "number", "")                     \
  _(TK_NEWLINE, "newline", "")                   \
  _(TK_INDENT, "indent", "")                     \
  _(TK_DEDENT, "dedent", "")                     \
  _(TK_DEF, "def", "def")                        \
  _(TK_EQUIVALENT, "equivalent", "<=>")          \
  _(TK_IDENT, "ident", "")                       \
  _(TK_STRING, "string", "")                     \
  _(TK_STRINGLITERAL, "string_literal", "")      \
  _(TK_CONST, "const", "")                       \
  _(TK_LIST, "list", "")                         \
  _(TK_DICT, "dict", "")                         \
  _(TK_OPTION, "option", "")                     \
  _(TK_APPLY, "apply", "")                       \
  _(TK_COMPREHENSION, "comprehension", "")       \
  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
  _(TK_PARAM, "param", "")                       \
  _(TK_INFERRED, "inferred", "")                 \
  _(TK_ACCESS, "access", "")                     \
  _(TK_ASSIGN, "assign", "")                     \
  _(TK_AUG_ASSIGN, "aug_assign", "")             \
  _(TK_ATTRIBUTE, "attribute", "")               \
  _(TK_IF, "if", "if")                           \
  _(TK_ELSE, "else", "else")                     \
  _(TK_ELIF, "elif", "elif")                     \
  _(TK_WHILE, "while", "while")                  \
  _(TK_EXPR_STMT, "expression statement", "")    \
  _(TK_RETURN, "return", "return")               \
  _(TK_IS, "is", "is")                           \
  _(TK_ISNOT, "is not", "is not")                \
  _(TK_NE, "ne", "!=")                           \
  _(TK_EQ, "eq", "==")                           \
  _(TK_LE, "le", "<=")                           \
  _(TK_GE, "ge", ">=")                           \
  _(TK_FLOOR_DIV, "floordiv", "//")              \
  _(TK_IF_EXPR, "if", "")                        \
  _(TK_TRUE, "True", "True")                     \
  _(TK_FALSE, "False", "False")                  \
  _(TK_NONE, "None", "None")                     \
  _(TK_AND, "and", "and")                        \
  _(TK_OR, "or", "or")                           \
  _(TK_NOT, "not", "not")                        \
  _(TK_LSHIFT, "<<", "<<")                       \
  _(TK_RSHIFT, ">>", ">>")                       \
  _(TK_CAST, "cast", "")                         \
  _(TK_PLUS_EQ, "+=", "+=")                      \
  _(TK_MINUS_EQ, "-=", "-=")                     \
  _(TK_TIMES_EQ, "*=", "*=")                     \
  _(TK_DIV_EQ, "/=", "/=")                       \
  _(TK_MOD_EQ, "%=", "%=")                       \
  _(TK_BIT_OR_EQ, "|=", "|=")                    \
  _(TK_BIT_AND_EQ, "&=", "&=")                   \
  _(TK_BIT_XOR_EQ, "^=", "^=")                   \
  _(TK_LSHIFT_EQ, "<<=", "<<=")                  \
  _(TK_RSHIFT_EQ, ">>=", ">>=")                  \
  _(TK_POW_EQ, "**=", "**=")                     \
  _(TK_GLOBAL, "global", "global")               \
  _(TK_BUILT_IN, "built-in", "")                 \
  _(TK_SUBSCRIPT, "subscript", "")               \
  _(TK_VAR, "variable", "")                      \
  _(TK_NOTHING, "nothing", "")                   \
  _(TK_DICT_LITERAL, "dict-literal", "")         \
  _(TK_LIST_LITERAL, "list-literal", "")         \
  _(TK_TUPLE_LITERAL, "tuple-literal", "")       \
  _(TK_FOR, "for", "for")                        \
  _(TK_IN, "in", "in")                           \
  _(TK_NOTIN, "not in", "not in")                \
  _(TK_STARRED, "starred", "")                   \
  _(TK_UNARY_MINUS, "unary minus", "")           \
  _(TK_POW, "pow operator", "**")                \
  _(TK_ARROW, "arrow", "->")                     \
  _(TK_DECL, "decl", "")                         \
  _(TK_SLICE_EXPR, "slice expr", "")             \
  _(TK_TYPE_COMMENT, "type comment", "# type:")  \
  _(TK_RAISE, "raise", "raise")                  \
  _(TK_ASSERT, "assert", "assert")               \
  _(TK_DOTS, "dots", "...")                      \
  _(TK_LIST_COMP, "list comprehension", "")      \
  _(TK_DICT_COMP, "dict comprehension", "")      \
  _(TK_BREAK, "break", "break")                  \
  _(TK_CONTINUE, "continue", "continue")         \
  _(TK_DELETE, "del", "del")                     \
  _(TK_PASS, "pass", "pass")                     \
  _(TK_CLASS_DEF, "class", "class")              \
  _(TK_IMPORT, "import", "import")               \
  _(TK_WITH, "with", "with")                     \
  _(TK_WITH_ITEM, "withitem", "")                \
  _(TK_AS, "as", "as")                           \
  _(TK_PROP, "property", "")                     \
  _(TK_ELLIPSIS, "Ellipsis", "Ellipsis")         \
  _(TK_NONE_TYPE, "NoneType", "NoneType")
122 
123 enum TokenKind {
124   // we use characters to represent themselves so skip all valid characters
125   // before
126   // assigning enum values to multi-char tokens.
127   TK_DUMMY_START = 256,
128 #define DEFINE_TOKEN(tok, _, _2) tok,
129   TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
130 #undef DEFINE_TOKEN
131 };
132 
133 TORCH_API std::string kindToString(int kind);
134 TORCH_API int stringToKind(const std::string& str);
135 
136 // nested hash tables that indicate char-by-char what is a valid token.
137 struct TokenTrie;
138 using TokenTrieRef = std::unique_ptr<TokenTrie>;
139 struct TokenTrie {
140   TokenTrie() = default;
insertTokenTrie141   void insert(const char* str, int tok) {
142     if (*str == '\0') {
143       AT_ASSERT(kind == 0);
144       kind = tok;
145       return;
146     }
147 
148     for (size_t i = 0, e = child_chars.size(); i < e; ++i) {
149       if (child_chars[i] == *str) {
150         child_tries[i]->insert(str + 1, tok);
151         return;
152       }
153     }
154 
155     child_chars.emplace_back(*str);
156     child_tries.emplace_back(std::make_unique<TokenTrie>());
157     child_tries.back()->insert(str + 1, tok);
158   }
159   int kind{0}; // 0 == invalid token
160 
161   std::vector<char> child_chars;
162   std::vector<TokenTrieRef> child_tries;
163 };
164 
165 // stuff that is shared against all TC lexers/parsers and is initialized only
166 // once.
167 struct TORCH_API SharedParserData {
SharedParserDataSharedParserData168   SharedParserData() : head(new TokenTrie()) {
169     for (const char* c = valid_single_char_tokens; *c; c++) {
170       std::string str(1, *c);
171       head->insert(str.c_str(), *c);
172     }
173 
174 #define ADD_CASE(tok, _, tokstring)   \
175   if (*(tokstring) != '\0') {         \
176     head->insert((tokstring), (tok)); \
177   }
178     TC_FORALL_TOKEN_KINDS(ADD_CASE)
179 #undef ADD_CASE
180   }
181 
matchSharedParserData182   bool match(
183       StringCordView::Iterator pos,
184       bool continuation, // are we inside a scope where newlines don't count
185                          // (e.g. inside parens)
186       bool whitespace_token, // should we treat whitespace as a token
187       int* kind,
188       StringCordView::Iterator* start,
189       StringCordView::Iterator* end) {
190     *start = pos;
191     // skip whitespace
192     while (pos.has_next() && isblank(*pos)) {
193       ++pos;
194     }
195 
196     // special handling
197     if (pos.has_next()) {
198       if (*pos == '#' && !isTypeComment(pos)) {
199         // skip comments
200         while (pos.has_next() && *pos != '\n')
201           ++pos;
202         // tail call, handle whitespace and more comments
203         return match(pos, continuation, whitespace_token, kind, start, end);
204       }
205       if (*pos == '\\') {
206         auto newiter = pos;
207         ++newiter;
208         if (newiter.has_next() && *newiter == '\n' && !whitespace_token) {
209           ++newiter;
210           return match(newiter, continuation, false, kind, start, end);
211         }
212       }
213       if (*pos == '\n') {
214         return match(++pos, continuation, !continuation, kind, start, end);
215       }
216     }
217     // we handle white space before EOF because in the case we have something
218     // like the following where we need to generate the dedent token if foo:
219     //   ...
220     // else:
221     //   pass
222     if (whitespace_token) {
223       *kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
224       *end = pos;
225       return true;
226     }
227     if (!pos.has_next()) {
228       *kind = TK_EOF;
229       *start = pos;
230       *end = *start;
231       return true;
232     }
233     // invariant: the next token is not whitespace or newline
234     *start = pos;
235     // check for a valid number
236     size_t len = 0;
237     if (isNumber(pos.rest_line(), 0, &len)) {
238       *end = *start;
239       *end += len;
240       *kind = TK_NUMBER;
241       return true;
242     }
243     // check for string
244     if (isString(pos.rest_line(), 0, &len)) {
245       *kind = TK_STRINGLITERAL;
246       *end = *start;
247       *end += len;
248       return true;
249     }
250 
251     // check for either an ident or a token
252     // ident tracks whether what we have scanned so far could be an identifier
253     // matched indicates if we have found any match.
254     bool matched = false;
255     bool ident = true;
256     TokenTrie* cur = head.get();
257     // for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr);
258     // i++)
259     for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
260          ++pos, ++i) {
261       ident = ident && validIdent(i, *pos);
262       if (ident) {
263         matched = true;
264         *end = pos.next_iter();
265         *kind = TK_IDENT;
266       }
267       // check for token second, so that e.g. 'max' matches the token TK_MAX
268       // rather the
269       // identifier 'max'
270       if (cur) {
271         const auto begin_it = cur->child_chars.begin();
272         const auto end_it = cur->child_chars.end();
273         const auto ch_it = std::find(begin_it, end_it, *pos);
274 
275         cur = (ch_it == end_it) ? nullptr
276                                 : cur->child_tries[ch_it - begin_it].get();
277 
278         if (cur && cur->kind != 0) {
279           matched = true;
280           *end = pos.next_iter();
281           *kind = cur->kind;
282         }
283       }
284     }
285     return matched;
286   }
287 
288   bool isUnary(int kind, int* prec);
289   bool isBinary(int kind, int* prec);
isRightAssociativeSharedParserData290   bool isRightAssociative(int kind) {
291     switch (kind) {
292       case '?':
293       case TK_POW:
294       case TK_IF:
295         return true;
296       default:
297         return false;
298     }
299   }
300 
301  private:
validIdentSharedParserData302   bool validIdent(size_t i, char n) {
303     return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
304   }
305 
306   // 1. skip whitespace
307   // 2. handle comment or newline
308   //
isNumberSharedParserData309   bool isNumber(c10::string_view str, size_t start, size_t* len) {
310     char first = str[start];
311     // strtod allows numbers to start with + or - or nan or inf
312     // http://en.cppreference.com/w/cpp/string/byte/strtof
313     // but we want only the number part, otherwise 1+3 will turn into two
314     // adjacent numbers in the lexer
315     if (first == '-' || first == '+' || isalpha(first))
316       return false;
317     const char* startptr = str.data() + start;
318     char* endptr = nullptr;
319     torch::jit::strtod_c(startptr, &endptr);
320     *len = endptr - startptr;
321     // check if the number is complex valued
322     // access is safe because string is assumed to be null terminated
323     if (endptr != nullptr && *endptr == 'j') {
324       *len += 1;
325     }
326     return *len > 0;
327   }
328 
isCharCountSharedParserData329   bool isCharCount(char c, c10::string_view str, size_t start, int len) {
330     // count checks from [start, start + len)
331     return start + len <= str.size() &&
332         std::count(str.begin() + start, str.begin() + start + len, c) == len;
333   }
334 
335   // python concatenates all adjacent strings "a" "b" == "ab"
336   // strings can be enclosed with 1 or 3 single or double quotes
337   // if enclosed with 3 quotes newlines are valid
338   // as elsewhere, backslash and new line should be ignored
isStringSharedParserData339   bool isString(c10::string_view str, size_t start, size_t* len) {
340     char quote = str[start];
341     if (quote != '\"' && quote != '\'')
342       return false;
343     int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;
344 
345     // end is now set past the opening quotation marks
346     size_t end = start + quote_len;
347     while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
348       if (str[end] == '\n' && quote_len != 3) {
349         return false;
350       }
351       // handle escaped characters. advances past escaped quotation marks,
352       // escaped newlines and escaped backslashes
353       // multi-char escapes like \x1A are handled fine here because the
354       // remainder of the escape are valid string characters anyway
355       if (str[end] == '\\') {
356         end++;
357       }
358       end++;
359     }
360     // set length equal to the complete string including quotations
361     *len = end - start + quote_len;
362     // if end finished without going past the last character of the string than
363     // there is a match
364     return end < str.size();
365   }
366 
isblankSharedParserData367   bool isblank(int n) {
368     return isspace(n) && n != '\n';
369   }
370 
isTypeCommentSharedParserData371   bool isTypeComment(StringCordView::Iterator str_iter) {
372     c10::string_view rest_line = str_iter.rest_line();
373     const std::string type_string = "# type:";
374     if (rest_line.size() < type_string.length()) {
375       return false;
376     }
377     auto match_string = rest_line.substr(0, type_string.size());
378     return match_string == type_string;
379   }
380 
381   // Make an exception ignoring comments for type annotation comments
isTypeCommentSharedParserData382   bool isTypeComment(const StringCordView& str, size_t pos) {
383     const std::string type_string = "# type:";
384     if (str.size() < pos + type_string.length()) {
385       return false;
386     }
387     auto match_string = str.substr(pos, type_string.size());
388     return match_string == type_string;
389   }
390 
391   TokenTrieRef head;
392 };
393 
394 TORCH_API SharedParserData& sharedParserData();
395 
396 struct Token {
397   int kind;
398   SourceRange range;
TokenToken399   Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
textToken400   std::string text() {
401     return std::string(range.token_text());
402   }
kindStringToken403   std::string kindString() const {
404     return kindToString(kind);
405   }
406 };
407 
408 struct Lexer {
LexerLexer409   explicit Lexer(std::shared_ptr<Source> source)
410       : source(std::move(source)),
411 
412         indent_stack(),
413         next_tokens(),
414         shared(sharedParserData()) {
415     auto first_indent = lexRaw(true);
416     indent_stack.push_back(first_indent.range.size());
417     lex();
418   }
419   // Return the current token, and then move to the next one
nextLexer420   Token next() {
421     if (next_tokens.empty())
422       reportError("Lexer invariant violated: empty token queue");
423     Token r = std::move(next_tokens.front());
424     next_tokens.erase(next_tokens.begin());
425     if (next_tokens.empty()) {
426       lex();
427     }
428     return r;
429   }
430   // Skip the current token if it matches the given kind
nextIfLexer431   bool nextIf(int kind) {
432     if (cur().kind != kind)
433       return false;
434     next();
435     return true;
436   }
437 
reportErrorLexer438   [[noreturn]] void reportError(const std::string& what) {
439     reportError(what, cur());
440   }
reportErrorLexer441   [[noreturn]] void reportError(const std::string& what, const Token& t) {
442     std::stringstream ss;
443     ss << what << ":\n";
444     t.range.highlight(ss);
445     throw std::runtime_error(ss.str());
446   }
expectedLexer447   [[noreturn]] void expected(const std::string& what, const Token& t) {
448     std::stringstream ss;
449     ss << "expected " << what << " but found '" << t.kindString()
450        << "' here:\n";
451     t.range.highlight(ss);
452     throw std::runtime_error(ss.str());
453   }
expectedLexer454   [[noreturn]] void expected(const std::string& what) {
455     expected(what, cur());
456   }
457   // Check that the current token has a given kind, return the current token,
458   // and advance to the next one.
expectLexer459   Token expect(int kind) {
460     if (cur().kind != kind) {
461       expected(kindToString(kind));
462     }
463     return next();
464   }
lookaheadLexer465   Token& lookahead() {
466     if (next_tokens.size() < 2) {
467       lex();
468     }
469     return next_tokens[1];
470   }
curLexer471   Token& cur() {
472     return next_tokens.front();
473   }
474 
475  private:
lexLexer476   void lex() {
477     auto r = lexRaw();
478     switch (r.kind) {
479       case '(':
480       case '[':
481       case '{':
482         nesting++;
483         break;
484       case ')':
485       case ']':
486       case '}':
487         nesting--;
488         break;
489       case TK_WHITESPACE:
490       case TK_WHITESPACE_EOF: {
491         const auto depth =
492             r.kind == TK_WHITESPACE_EOF ? indent_stack.front() : r.range.size();
493         // note: TK_WHITESPACE_EOF is whitespace right before the EOF token
494         // just like we allow the code to be indented to a particular initial
495         // indent level, we allow the final indent to be anything and set
496         // it back to the initial indent level. This allows the code to be
497         // put into string literals inside code without worrying about final
498         // whitespace
499         if (depth > indent_stack.back()) {
500           indent_stack.push_back(depth);
501           r.kind = TK_INDENT;
502         } else if (depth == indent_stack.back()) {
503           r.kind = TK_NEWLINE;
504         } else {
505           next_tokens.emplace_back(TK_NEWLINE, r.range);
506           while (indent_stack.back() != depth) {
507             indent_stack.pop_back();
508             next_tokens.emplace_back(TK_DEDENT, r.range);
509             if (indent_stack.empty()) {
510               reportError("invalid indent level " + std::to_string(depth), r);
511             }
512           }
513           return; // We've already queued the tokens
514         }
515       } break;
516       default:
517         break;
518     }
519     next_tokens.push_back(std::move(r));
520   }
521   Token lexRaw(bool whitespace_token = false) {
522     AT_ASSERT(source);
523     if (current == nullptr) {
524       AT_ASSERT(pos == 0);
525       current = std::make_unique<StringCordView::Iterator>(
526           source->text_str().begin());
527     }
528 
529     StringCordView::Iterator start_iter = *current;
530     StringCordView::Iterator end_iter = *current;
531     int kind = 0;
532     if (!shared.match(
533             *current,
534             nesting > 0,
535             whitespace_token,
536             &kind,
537             &start_iter,
538             &end_iter)) {
539       expected(
540           "a valid token",
541           Token(
542               **current,
543               SourceRange(source, start_iter, start_iter.pos() + 1)));
544     }
545 
546     auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos()));
547     pos = end_iter.pos();
548     *current = end_iter;
549     return t;
550   }
551 
552   std::shared_ptr<Source> source;
553   std::unique_ptr<StringCordView::Iterator> current;
554   size_t pos{0};
555   size_t nesting{0}; // depth of ( [ { nesting...
556   std::vector<size_t> indent_stack; // stack of indentation level of blocks
557   // Invariant: this should always contain at least a single element
558   std::vector<Token> next_tokens;
559   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
560   SharedParserData& shared;
561 };
562 } // namespace torch::jit
563