1 #pragma once 2 #include <c10/macros/Macros.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/jit/frontend/parser_constants.h> 6 #include <torch/csrc/jit/frontend/source_range.h> 7 #include <torch/csrc/jit/frontend/strtod.h> 8 #include <algorithm> 9 #include <clocale> 10 #include <cstdlib> 11 #include <memory> 12 #include <sstream> 13 #include <string> 14 #include <vector> 15 16 namespace torch::jit { 17 18 // single character tokens are just the character itself '+' 19 // multi-character tokens need an entry here 20 // if the third entry is not the empty string, it is used 21 // in the lexer to match this token. 22 23 // These kinds are also used in Tree.h as the kind of the AST node. 24 // Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the 25 // lexer. 26 27 #define TC_FORALL_TOKEN_KINDS(_) \ 28 _(TK_EOF, "eof", "") \ 29 _(TK_WHITESPACE, "whitespace", "") \ 30 _(TK_WHITESPACE_EOF, "whitespace_eof", "") \ 31 _(TK_NUMBER, "number", "") \ 32 _(TK_NEWLINE, "newline", "") \ 33 _(TK_INDENT, "indent", "") \ 34 _(TK_DEDENT, "dedent", "") \ 35 _(TK_DEF, "def", "def") \ 36 _(TK_EQUIVALENT, "equivalent", "<=>") \ 37 _(TK_IDENT, "ident", "") \ 38 _(TK_STRING, "string", "") \ 39 _(TK_STRINGLITERAL, "string_literal", "") \ 40 _(TK_CONST, "const", "") \ 41 _(TK_LIST, "list", "") \ 42 _(TK_DICT, "dict", "") \ 43 _(TK_OPTION, "option", "") \ 44 _(TK_APPLY, "apply", "") \ 45 _(TK_COMPREHENSION, "comprehension", "") \ 46 _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ 47 _(TK_PARAM, "param", "") \ 48 _(TK_INFERRED, "inferred", "") \ 49 _(TK_ACCESS, "access", "") \ 50 _(TK_ASSIGN, "assign", "") \ 51 _(TK_AUG_ASSIGN, "aug_assign", "") \ 52 _(TK_ATTRIBUTE, "attribute", "") \ 53 _(TK_IF, "if", "if") \ 54 _(TK_ELSE, "else", "else") \ 55 _(TK_ELIF, "elif", "elif") \ 56 _(TK_WHILE, "while", "while") \ 57 _(TK_EXPR_STMT, "expression statement", "") \ 58 _(TK_RETURN, "return", "return") \ 59 _(TK_IS, "is", "is") \ 60 
_(TK_ISNOT, "is not", "is not") \ 61 _(TK_NE, "ne", "!=") \ 62 _(TK_EQ, "eq", "==") \ 63 _(TK_LE, "le", "<=") \ 64 _(TK_GE, "ge", ">=") \ 65 _(TK_FLOOR_DIV, "floordiv", "//") \ 66 _(TK_IF_EXPR, "if", "") \ 67 _(TK_TRUE, "True", "True") \ 68 _(TK_FALSE, "False", "False") \ 69 _(TK_NONE, "None", "None") \ 70 _(TK_AND, "and", "and") \ 71 _(TK_OR, "or", "or") \ 72 _(TK_NOT, "not", "not") \ 73 _(TK_LSHIFT, "<<", "<<") \ 74 _(TK_RSHIFT, ">>", ">>") \ 75 _(TK_CAST, "cast", "") \ 76 _(TK_PLUS_EQ, "+=", "+=") \ 77 _(TK_MINUS_EQ, "-=", "-=") \ 78 _(TK_TIMES_EQ, "*=", "*=") \ 79 _(TK_DIV_EQ, "/=", "/=") \ 80 _(TK_MOD_EQ, "%=", "%=") \ 81 _(TK_BIT_OR_EQ, "|=", "|=") \ 82 _(TK_BIT_AND_EQ, "&=", "&=") \ 83 _(TK_BIT_XOR_EQ, "^=", "^=") \ 84 _(TK_LSHIFT_EQ, "<<=", "<<=") \ 85 _(TK_RSHIFT_EQ, ">>=", ">>=") \ 86 _(TK_POW_EQ, "**=", "**=") \ 87 _(TK_GLOBAL, "global", "global") \ 88 _(TK_BUILT_IN, "built-in", "") \ 89 _(TK_SUBSCRIPT, "subscript", "") \ 90 _(TK_VAR, "variable", "") \ 91 _(TK_NOTHING, "nothing", "") \ 92 _(TK_DICT_LITERAL, "dict-literal", "") \ 93 _(TK_LIST_LITERAL, "list-literal", "") \ 94 _(TK_TUPLE_LITERAL, "tuple-literal", "") \ 95 _(TK_FOR, "for", "for") \ 96 _(TK_IN, "in", "in") \ 97 _(TK_NOTIN, "not in", "not in") \ 98 _(TK_STARRED, "starred", "") \ 99 _(TK_UNARY_MINUS, "unary minus", "") \ 100 _(TK_POW, "pow operator", "**") \ 101 _(TK_ARROW, "arrow", "->") \ 102 _(TK_DECL, "decl", "") \ 103 _(TK_SLICE_EXPR, "slice expr", "") \ 104 _(TK_TYPE_COMMENT, "type comment", "# type:") \ 105 _(TK_RAISE, "raise", "raise") \ 106 _(TK_ASSERT, "assert", "assert") \ 107 _(TK_DOTS, "dots", "...") \ 108 _(TK_LIST_COMP, "list comprehension", "") \ 109 _(TK_DICT_COMP, "dict comprehension", "") \ 110 _(TK_BREAK, "break", "break") \ 111 _(TK_CONTINUE, "continue", "continue") \ 112 _(TK_DELETE, "del", "del") \ 113 _(TK_PASS, "pass", "pass") \ 114 _(TK_CLASS_DEF, "class", "class") \ 115 _(TK_IMPORT, "import", "import") \ 116 _(TK_WITH, "with", "with") \ 117 _(TK_WITH_ITEM, 
"withitem", "") \ 118 _(TK_AS, "as", "as") \ 119 _(TK_PROP, "property", "") \ 120 _(TK_ELLIPSIS, "Ellipsis", "Ellipsis") \ 121 _(TK_NONE_TYPE, "NoneType", "NoneType") 122 123 enum TokenKind { 124 // we use characters to represent themselves so skip all valid characters 125 // before 126 // assigning enum values to multi-char tokens. 127 TK_DUMMY_START = 256, 128 #define DEFINE_TOKEN(tok, _, _2) tok, 129 TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN) 130 #undef DEFINE_TOKEN 131 }; 132 133 TORCH_API std::string kindToString(int kind); 134 TORCH_API int stringToKind(const std::string& str); 135 136 // nested hash tables that indicate char-by-char what is a valid token. 137 struct TokenTrie; 138 using TokenTrieRef = std::unique_ptr<TokenTrie>; 139 struct TokenTrie { 140 TokenTrie() = default; insertTokenTrie141 void insert(const char* str, int tok) { 142 if (*str == '\0') { 143 AT_ASSERT(kind == 0); 144 kind = tok; 145 return; 146 } 147 148 for (size_t i = 0, e = child_chars.size(); i < e; ++i) { 149 if (child_chars[i] == *str) { 150 child_tries[i]->insert(str + 1, tok); 151 return; 152 } 153 } 154 155 child_chars.emplace_back(*str); 156 child_tries.emplace_back(std::make_unique<TokenTrie>()); 157 child_tries.back()->insert(str + 1, tok); 158 } 159 int kind{0}; // 0 == invalid token 160 161 std::vector<char> child_chars; 162 std::vector<TokenTrieRef> child_tries; 163 }; 164 165 // stuff that is shared against all TC lexers/parsers and is initialized only 166 // once. 
struct TORCH_API SharedParserData {
  // Build the token-matching trie: every valid single-character token maps
  // to its own character value, and every entry of TC_FORALL_TOKEN_KINDS
  // with non-empty lexer text maps to its TK_* kind.
  SharedParserData() : head(new TokenTrie()) {
    for (const char* c = valid_single_char_tokens; *c; c++) {
      std::string str(1, *c);
      head->insert(str.c_str(), *c);
    }

#define ADD_CASE(tok, _, tokstring)   \
  if (*(tokstring) != '\0') {         \
    head->insert((tokstring), (tok)); \
  }
    TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
  }

  // Scan for one token starting at `pos`. On success returns true, storing
  // the token kind in *kind and its extent in [*start, *end). Comments and
  // backslash line continuations are skipped via tail-recursive calls.
  // Returns false only when no identifier, number, string or trie token
  // matches at the position.
  bool match(
      StringCordView::Iterator pos,
      bool continuation, // are we inside a scope where newlines don't count
                         // (e.g. inside parens)
      bool whitespace_token, // should we treat whitespace as a token
      int* kind,
      StringCordView::Iterator* start,
      StringCordView::Iterator* end) {
    *start = pos;
    // skip whitespace (but not newlines; isblank excludes '\n')
    while (pos.has_next() && isblank(*pos)) {
      ++pos;
    }

    // special handling
    if (pos.has_next()) {
      if (*pos == '#' && !isTypeComment(pos)) {
        // skip comments (type comments are real tokens, so not skipped)
        while (pos.has_next() && *pos != '\n')
          ++pos;
        // tail call, handle whitespace and more comments
        return match(pos, continuation, whitespace_token, kind, start, end);
      }
      if (*pos == '\\') {
        // backslash-newline is a line continuation: resume lexing on the
        // next line as if the break were not there
        auto newiter = pos;
        ++newiter;
        if (newiter.has_next() && *newiter == '\n' && !whitespace_token) {
          ++newiter;
          return match(newiter, continuation, false, kind, start, end);
        }
      }
      if (*pos == '\n') {
        // after a newline, request a whitespace token (for indentation
        // tracking) unless we are inside parens/brackets (continuation)
        return match(++pos, continuation, !continuation, kind, start, end);
      }
    }
    // we handle white space before EOF because in the case we have something
    // like the following where we need to generate the dedent token if foo:
    //   ...
    // else:
    //   pass
    if (whitespace_token) {
      *kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
      *end = pos;
      return true;
    }
    if (!pos.has_next()) {
      *kind = TK_EOF;
      *start = pos;
      *end = *start;
      return true;
    }
    // invariant: the next token is not whitespace or newline
    *start = pos;
    // check for a valid number
    size_t len = 0;
    if (isNumber(pos.rest_line(), 0, &len)) {
      *end = *start;
      *end += len;
      *kind = TK_NUMBER;
      return true;
    }
    // check for string
    if (isString(pos.rest_line(), 0, &len)) {
      *kind = TK_STRINGLITERAL;
      *end = *start;
      *end += len;
      return true;
    }

    // check for either an ident or a token
    // ident tracks whether what we have scanned so far could be an identifier
    // matched indicates if we have found any match.
    // Both candidates advance together; the last assignment to *kind wins,
    // so a keyword that is also a valid identifier prefers the keyword.
    bool matched = false;
    bool ident = true;
    TokenTrie* cur = head.get();
    for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
         ++pos, ++i) {
      ident = ident && validIdent(i, *pos);
      if (ident) {
        matched = true;
        *end = pos.next_iter();
        *kind = TK_IDENT;
      }
      // check for token second, so that e.g. 'max' matches the token TK_MAX
      // rather the
      // identifier 'max'
      if (cur) {
        const auto begin_it = cur->child_chars.begin();
        const auto end_it = cur->child_chars.end();
        const auto ch_it = std::find(begin_it, end_it, *pos);

        cur = (ch_it == end_it) ? nullptr
                                : cur->child_tries[ch_it - begin_it].get();

        if (cur && cur->kind != 0) {
          matched = true;
          *end = pos.next_iter();
          *kind = cur->kind;
        }
      }
    }
    return matched;
  }

  // Operator precedence queries; defined in the .cpp.
  bool isUnary(int kind, int* prec);
  bool isBinary(int kind, int* prec);
  // Right-associative operators: ternary '?', '**', and conditional
  // expressions.
  bool isRightAssociative(int kind) {
    switch (kind) {
      case '?':
      case TK_POW:
      case TK_IF:
        return true;
      default:
        return false;
    }
  }

 private:
  // Identifier characters: alpha or '_' anywhere, digits allowed after the
  // first character.
  bool validIdent(size_t i, char n) {
    return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
  }

  // True if `str[start..]` begins with a numeric literal; *len receives its
  // length (including a trailing 'j' for complex literals).
  bool isNumber(c10::string_view str, size_t start, size_t* len) {
    char first = str[start];
    // strtod allows numbers to start with + or - or nan or inf
    // http://en.cppreference.com/w/cpp/string/byte/strtof
    // but we want only the number part, otherwise 1+3 will turn into two
    // adjacent numbers in the lexer
    if (first == '-' || first == '+' || isalpha(first))
      return false;
    const char* startptr = str.data() + start;
    char* endptr = nullptr;
    torch::jit::strtod_c(startptr, &endptr);
    *len = endptr - startptr;
    // check if the number is complex valued
    // access is safe because string is assumed to be null terminated
    if (endptr != nullptr && *endptr == 'j') {
      *len += 1;
    }
    return *len > 0;
  }

  // True if str[start, start + len) consists entirely of character `c`.
  bool isCharCount(char c, c10::string_view str, size_t start, int len) {
    // count checks from [start, start + len)
    return start + len <= str.size() &&
        std::count(str.begin() + start, str.begin() + start + len, c) == len;
  }

  // python concatenates all adjacent strings "a" "b" == "ab"
  // strings can be enclosed with 1 or 3 single or double quotes
  // if enclosed with 3 quotes newlines are valid
  // as elsewhere, backslash and new line should be ignored
  // On success *len is the full literal length including both quote runs;
  // returns false for an unterminated literal (end ran off the string).
  bool isString(c10::string_view str, size_t start, size_t* len) {
    char quote = str[start];
    if (quote != '\"' && quote != '\'')
      return false;
    int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;

    // end is now set past the opening quotation marks
    size_t end = start + quote_len;
    while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
      if (str[end] == '\n' && quote_len != 3) {
        return false;
      }
      // handle escaped characters. advances past escaped quotation marks,
      // escaped newlines and escaped backslashes
      // multi-char escapes like \x1A are handled fine here because the
      // remainder of the escape are valid string characters anyway
      if (str[end] == '\\') {
        end++;
      }
      end++;
    }
    // set length equal to the complete string including quotations
    *len = end - start + quote_len;
    // if end finished without going past the last character of the string
    // than there is a match
    return end < str.size();
  }

  // Whitespace that is not a newline (newlines are structurally
  // significant and handled separately in match()).
  bool isblank(int n) {
    return isspace(n) && n != '\n';
  }

  // True if the rest of the current line begins with "# type:".
  bool isTypeComment(StringCordView::Iterator str_iter) {
    c10::string_view rest_line = str_iter.rest_line();
    const std::string type_string = "# type:";
    if (rest_line.size() < type_string.length()) {
      return false;
    }
    auto match_string = rest_line.substr(0, type_string.size());
    return match_string == type_string;
  }

  // Make an exception ignoring comments for type annotation comments
  bool isTypeComment(const StringCordView& str, size_t pos) {
    const std::string type_string = "# type:";
    if (str.size() < pos + type_string.length()) {
      return false;
    }
    auto match_string = str.substr(pos, type_string.size());
    return match_string == type_string;
  }

  // Root of the token-matching trie built in the constructor.
  TokenTrieRef head;
};

TORCH_API
SharedParserData& sharedParserData(); 395 396 struct Token { 397 int kind; 398 SourceRange range; TokenToken399 Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {} textToken400 std::string text() { 401 return std::string(range.token_text()); 402 } kindStringToken403 std::string kindString() const { 404 return kindToString(kind); 405 } 406 }; 407 408 struct Lexer { LexerLexer409 explicit Lexer(std::shared_ptr<Source> source) 410 : source(std::move(source)), 411 412 indent_stack(), 413 next_tokens(), 414 shared(sharedParserData()) { 415 auto first_indent = lexRaw(true); 416 indent_stack.push_back(first_indent.range.size()); 417 lex(); 418 } 419 // Return the current token, and then move to the next one nextLexer420 Token next() { 421 if (next_tokens.empty()) 422 reportError("Lexer invariant violated: empty token queue"); 423 Token r = std::move(next_tokens.front()); 424 next_tokens.erase(next_tokens.begin()); 425 if (next_tokens.empty()) { 426 lex(); 427 } 428 return r; 429 } 430 // Skip the current token if it matches the given kind nextIfLexer431 bool nextIf(int kind) { 432 if (cur().kind != kind) 433 return false; 434 next(); 435 return true; 436 } 437 reportErrorLexer438 [[noreturn]] void reportError(const std::string& what) { 439 reportError(what, cur()); 440 } reportErrorLexer441 [[noreturn]] void reportError(const std::string& what, const Token& t) { 442 std::stringstream ss; 443 ss << what << ":\n"; 444 t.range.highlight(ss); 445 throw std::runtime_error(ss.str()); 446 } expectedLexer447 [[noreturn]] void expected(const std::string& what, const Token& t) { 448 std::stringstream ss; 449 ss << "expected " << what << " but found '" << t.kindString() 450 << "' here:\n"; 451 t.range.highlight(ss); 452 throw std::runtime_error(ss.str()); 453 } expectedLexer454 [[noreturn]] void expected(const std::string& what) { 455 expected(what, cur()); 456 } 457 // Check that the current token has a given kind, return the current token, 458 // and 
advance to the next one. expectLexer459 Token expect(int kind) { 460 if (cur().kind != kind) { 461 expected(kindToString(kind)); 462 } 463 return next(); 464 } lookaheadLexer465 Token& lookahead() { 466 if (next_tokens.size() < 2) { 467 lex(); 468 } 469 return next_tokens[1]; 470 } curLexer471 Token& cur() { 472 return next_tokens.front(); 473 } 474 475 private: lexLexer476 void lex() { 477 auto r = lexRaw(); 478 switch (r.kind) { 479 case '(': 480 case '[': 481 case '{': 482 nesting++; 483 break; 484 case ')': 485 case ']': 486 case '}': 487 nesting--; 488 break; 489 case TK_WHITESPACE: 490 case TK_WHITESPACE_EOF: { 491 const auto depth = 492 r.kind == TK_WHITESPACE_EOF ? indent_stack.front() : r.range.size(); 493 // note: TK_WHITESPACE_EOF is whitespace right before the EOF token 494 // just like we allow the code to be indented to a particular initial 495 // indent level, we allow the final indent to be anything and set 496 // it back to the initial indent level. This allows the code to be 497 // put into string literals inside code without worrying about final 498 // whitespace 499 if (depth > indent_stack.back()) { 500 indent_stack.push_back(depth); 501 r.kind = TK_INDENT; 502 } else if (depth == indent_stack.back()) { 503 r.kind = TK_NEWLINE; 504 } else { 505 next_tokens.emplace_back(TK_NEWLINE, r.range); 506 while (indent_stack.back() != depth) { 507 indent_stack.pop_back(); 508 next_tokens.emplace_back(TK_DEDENT, r.range); 509 if (indent_stack.empty()) { 510 reportError("invalid indent level " + std::to_string(depth), r); 511 } 512 } 513 return; // We've already queued the tokens 514 } 515 } break; 516 default: 517 break; 518 } 519 next_tokens.push_back(std::move(r)); 520 } 521 Token lexRaw(bool whitespace_token = false) { 522 AT_ASSERT(source); 523 if (current == nullptr) { 524 AT_ASSERT(pos == 0); 525 current = std::make_unique<StringCordView::Iterator>( 526 source->text_str().begin()); 527 } 528 529 StringCordView::Iterator start_iter = *current; 530 
StringCordView::Iterator end_iter = *current; 531 int kind = 0; 532 if (!shared.match( 533 *current, 534 nesting > 0, 535 whitespace_token, 536 &kind, 537 &start_iter, 538 &end_iter)) { 539 expected( 540 "a valid token", 541 Token( 542 **current, 543 SourceRange(source, start_iter, start_iter.pos() + 1))); 544 } 545 546 auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos())); 547 pos = end_iter.pos(); 548 *current = end_iter; 549 return t; 550 } 551 552 std::shared_ptr<Source> source; 553 std::unique_ptr<StringCordView::Iterator> current; 554 size_t pos{0}; 555 size_t nesting{0}; // depth of ( [ { nesting... 556 std::vector<size_t> indent_stack; // stack of indentation level of blocks 557 // Invariant: this should always contain at least a single element 558 std::vector<Token> next_tokens; 559 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 560 SharedParserData& shared; 561 }; 562 } // namespace torch::jit 563