# mypy: allow-untyped-defs
import ast
import dataclasses
import inspect
import re
import string
import sys
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple  # noqa: F401

import torch
import torch.jit.annotations
from torch import _jit_internal
from torch._C._jit_tree_views import (
    Apply,
    Assert,
    Assign,
    Attribute,
    AugAssign,
    BinOp,
    Break,
    ClassDef,
    Const,
    Continue,
    Decl,
    Def,
    Delete,
    DictComp,
    DictLiteral,
    Dots,
    EmptyTypeAnnotation,
    ExprStmt,
    FalseLiteral,
    For,
    Ident,
    If,
    ListComp,
    ListLiteral,
    NoneLiteral,
    Param,
    Pass,
    Property,
    Raise,
    Return,
    Select,
    SliceExpr,
    Starred,
    Stmt,
    StringLiteral,
    Subscript,
    TernaryIf,
    TrueLiteral,
    TupleLiteral,
    UnaryOp,
    Var,
    While,
    With,
    WithItem,
)
from torch._jit_internal import (  # noqa: F401
    _is_drop_fn,
    FunctionModifiers,
    is_static_fn,
    should_drop,
)
from torch._sources import (
    get_source_lines_and_file,
    make_source_context,
    parse_def,
    ParsedDef as _ParsedDef,
)
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace


# `astunparse` is optional: it is only needed when translating a
# `torch.jit._IgnoreContextManager` block (build_ignore_context_manager).
_IS_ASTUNPARSE_INSTALLED = False
try:
    import astunparse  # type: ignore[import]

    _IS_ASTUNPARSE_INSTALLED = True
except ImportError:
    pass

# Names the frontend reserves for itself: anything starting with
# `_reserved_prefix`, plus the names listed in `_reserved_names`.
_reserved_prefix = "__jit"
_reserved_names = {"print"}
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)


def is_reserved_name(name):
    """Return True if ``name`` is reserved by the TorchScript frontend.

    A name is reserved when it starts with ``_reserved_prefix`` or is one of
    ``_reserved_names``.
    """
    return name.startswith(_reserved_prefix) or name in _reserved_names


# Human-readable (plural) description of each Python construct; used by
# UnsupportedNodeError to build "<feature> aren't supported" messages.
pretty_node_names = {
    ast.FunctionDef: "function definitions",
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
    ast.AsyncFunctionDef: "async function definitions",
    ast.AsyncFor: "async for loops",
    ast.AsyncWith: "async with statements",
    ast.Try: "try blocks",
    ast.Nonlocal: "nonlocal variables",
    ast.AnnAssign: "annotated assignments",
}

# The keyword each construct starts with. Its length is the width of the
# source range highlighted by UnsupportedNodeError.
# NB: no specific token for AnnAssign (it falls back to a 1-character range).
node_start_tokens = {
    ast.FunctionDef: "def",
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
    ast.AsyncFunctionDef: "async def",
    ast.AsyncFor: "async for",
    ast.AsyncWith: "async with",
    ast.Try: "try",
    ast.Nonlocal: "nonlocal",
}


class FrontendError(Exception):
    """Error raised while translating Python source into TorchScript tree views.

    Args:
        source_range: The source range the error points at.
        msg: The error message; the rendered ErrorReport is appended by __str__.
    """

    def __init__(self, source_range, msg):
        self.source_range = source_range
        self.msg = msg

        # This has to be instantiated here so the ErrorReport is accurate to the
        # call stack when the FrontendError was raised
        self.error_report = torch._C.ErrorReport(self.source_range)

    def __str__(self):
        return self.msg + self.error_report.what().lstrip()


class NotSupportedError(FrontendError):
    """Raised when the source uses a Python feature the frontend does not support."""


class UnsupportedNodeError(NotSupportedError):
    """NotSupportedError for an entire kind of AST node.

    The message comes from ``pretty_node_names`` (falling back to the node's
    class name). The highlighted range covers the node's leading token from
    ``node_start_tokens``.
    """

    def __init__(self, ctx, offending_node, reason=""):
        # If we don't have a specific token, we default to length of 1
        node_type = type(offending_node)
        range_len = len(node_start_tokens.get(node_type, " "))
        source_range = ctx.make_range(
            offending_node.lineno,
            offending_node.col_offset,
            offending_node.col_offset + range_len,
        )
        feature_name = pretty_node_names.get(node_type, node_type.__name__)
        msg = f"{feature_name} {reason + ' ' if reason else ''}aren't supported"
        super().__init__(source_range, msg)


class FrontendTypeError(FrontendError):
    pass


def build_withitems(ctx, items):
    """Translate each ``ast.withitem`` in ``items`` with ``build_withitem``."""
    return [build_withitem(ctx, i) for i in items]


def build_stmts(ctx, stmts):
    """Translate a list of statements, dropping falsy results.

    Statement builders return None for statements that produce nothing in
    TorchScript, so those entries are filtered out.
    """
    built = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, built))


def get_class_properties(cls, self_name):
    """
    Get a list of Property objects representing the properties of a class.

    Properties named in the class's ``__jit_unused_properties__`` list, and
    properties whose getter ``should_drop``, are skipped.

    Args:
        cls: The class to get properties of.
        self_name: The name of the class that the properties should belong to.
    Returns:
        A list of Property objects corresponding to the properties of cls. Property
        here refers to the subclass of TreeView.
    """
    props = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # Any property that should not compiled must be in this list on the Module.
    unused_properties = getattr(cls, "__jit_unused_properties__", [])

    # Create Property TreeView objects from inspected property objects.
    properties = []
    for prop_name, prop in props:
        if prop_name in unused_properties or should_drop(prop.fget):
            continue
        getter = get_jit_def(prop.fget, f"__{prop_name}_getter", self_name=self_name)
        setter = (
            get_jit_def(prop.fset, f"__{prop_name}_setter", self_name=self_name)
            if prop.fset
            else None
        )
        properties.append(
            Property(getter.range(), Ident(getter.range(), prop_name), getter, setter)
        )

    return properties


def get_class_assigns(ctx, cls_ast):
    """Build Assign tree views for class-level (annotated) assignments.

    Assignments the builders reject with NotSupportedError are skipped
    deliberately rather than failing the whole class.

    Args:
        ctx: The source context for the class.
        cls_ast: The ``ast.ClassDef`` node of the class.
    Returns:
        A list of the Assign tree views that could be built.
    """
    assigns = []

    def maybe_build_assign(builder, entry):
        try:
            assigns.append(builder(ctx, entry))
        except NotSupportedError:
            pass

    for entry in cls_ast.body:
        if isinstance(entry, ast.Assign):
            maybe_build_assign(StmtBuilder.build_Assign, entry)
        elif isinstance(entry, ast.AnnAssign):
            maybe_build_assign(StmtBuilder.build_AnnAssign, entry)
    return assigns


def get_jit_class_def(cls, self_name):
    """Get definitions for each method within the current class independently.

    Args:
        cls: The class to get definition of.
        self_name: The name of the class that the properties should belong to.

    Returns:
        torch._C._jit_tree_views.ClassDef: A representation of the class,
            the methods in the class and their definition as a tree.
    """
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
        and not _is_drop_fn(m),
    )

    def is_classmethod(fn):
        # A classmethod looked up on the class is a bound method whose
        # __self__ is the class object itself, so compare by identity.
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) is cls

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack()
    )
    source = "".join(sourcelines)

    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code for
    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
    # magic methods for classes, and we can't get the source code for these methods. As a
    # workaround, we synthesize TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Detect whether the user manually implemented any of the magic methods. If they did,
        # we don't want to synthesize/override them.
        overrides = {
            method.name
            for method in class_ast.body
            if isinstance(method, ast.FunctionDef)
            and method.name in DATACLASS_MAGIC_METHODS
        }
        for i, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if synthesizer_fn and name not in overrides:
                parsed_def = synthesizer_fn(cls)
                methods[i] = name, parsed_def
                func = getattr(cls, name)
                _jit_internal.loader.cache(func, parsed_def.source)

    method_defs = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    # Width of the indentation that dedent() stripped from the first line; the
    # source context needs it to map ranges back into the original file.
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, False
    )
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)


def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: If True, a statement ``<first_arg> = <self_name>`` is
            inserted at the top of the body so the first argument refers to the class.

    Returns:
        The Def tree view produced by ``build_def``.
    """
    parsed_def = parse_def(fn) if not isinstance(fn, _ParsedDef) else fn
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    if is_classmethod:
        arg_name = fn_def.args.args[0].arg
        # Insert a statement that assigns the first argument to the class
        assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
        fn_def.body.insert(0, assign_stmt)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        unused_fn_def = ast.parse(
            'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")'
        )
        if len(unused_fn_def.body) != 1 or not isinstance(
            unused_fn_def.body[0], ast.FunctionDef
        ):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        unused_def = unused_fn_def.body[0]
        fn_def.body = unused_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            # Replace potentially unsupported type annotations by "Any"
            arg.annotation = unused_def.args.args[0].annotation
        if _is_drop_fn(fn):
            # Dropping potentially unsupported return type annotation for jit._drop
            fn_def.returns = None
            fn_def.type_comment = None

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace and not isinstance(fn, _ParsedDef):  # type: ignore[truthy-function]
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )


# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
    """Return True if the first item of a ``with`` statement calls
    ``torch.jit._IgnoreContextManager(...)``, spelled exactly that way."""
    # checks if the statement is torch.jit.ignore context manager
    if isinstance(stmt.items[0].context_expr, ast.Call):
        # extract torch part
        function = stmt.items[0].context_expr.func
        if isinstance(function, ast.Attribute):
            attr_name = function.attr
            attr_value = function.value
            if attr_name == "_IgnoreContextManager" and isinstance(
                attr_value, ast.Attribute
            ):
                # there should be at most two nested attributes (e.g torch.jit._IgnoreContextManager)
                if attr_value.attr == "jit" and isinstance(attr_value.value, ast.Name):
                    if attr_value.value.id == "torch":
                        return True
    return False


class Builder:
    """Dispatching translator: ``builder(ctx, node)`` calls ``self.build_<NodeClass>``."""

    def __call__(self, ctx, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            # No builder for this node kind -> report it as unsupported.
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)


def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
    """Build a ClassDef tree view named ``self_name``.

    The range spans the ``class`` keyword. Each already-built method Def is
    wrapped in a Stmt.
    """
    r = ctx.make_range(
        py_def.lineno, py_def.col_offset, py_def.col_offset + len("class")
    )
    return ClassDef(
        Ident(r, self_name), [Stmt(method) for method in methods], properties, assigns
    )


def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    """Build a Def tree view for the Python function node ``py_def``.

    Args:
        ctx: Source context used to create ranges.
        py_def: The ``ast.FunctionDef`` to translate.
        type_line: A ``# type:`` comment line, or None; when given, its types
            are merged into the declaration.
        def_name: Name given to the resulting Def.
        self_name: Type name of ``self`` if this is a method, else None.
        pdt_arg_types: Optional profile-directed argument types, by arg name.
    """
    body = py_def.body
    r = ctx.make_range(py_def.lineno, py_def.col_offset, py_def.col_offset + len("def"))

    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    return_type = None
    if getattr(py_def, "returns", None) is not None:
        return_type = build_expr(ctx, py_def.returns)

    decl = Decl(r, param_list, return_type)
    is_method = self_name is not None
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)

    return Def(Ident(r, def_name), decl, build_stmts(ctx, body))


# Error message shared by every rejection in build_param_list.
_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments "
    "or use keyword-only arguments with defaults"
)


def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
    """Build Param tree views for positional and keyword-only arguments.

    Raises:
        NotSupportedError: if the function has ``**kwargs``, ``*args``, or a
            keyword-only argument with a default value.
    """
    if py_args.kwarg is not None:
        expr = py_args.kwarg
        ctx_range = ctx.make_range(
            expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)
        )
        raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    if py_args.vararg is not None:
        expr = py_args.vararg
        ctx_range = ctx.make_range(
            expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)
        )
        raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    if len(py_args.kw_defaults) > 0:
        # kw_defaults is a list of the values for the kwargs (which default to None),
        # so they don't actually have line numbers.
        for arg in py_args.kw_defaults:
            if arg is not None:
                ctx_range = build_expr(ctx, arg).range()
                raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    # List of Tuple of args and type as inferred by profile directed typing
    # NOTE(review): pdt_arg_types[arg.arg] raises KeyError if the trace lacks an
    # entry for some argument name -- confirm get_args_types always covers all args.
    arg_and_types = [
        (
            arg,
            pdt_arg_types[arg.arg]
            if pdt_arg_types and bool(pdt_arg_types[arg.arg])
            else None,
        )
        for arg in py_args.args
    ]
    arg_and_types_kwonlyargs = [
        (
            arg,
            pdt_arg_types[arg.arg]
            if pdt_arg_types and bool(pdt_arg_types[arg.arg])
            else None,
        )
        for arg in py_args.kwonlyargs
    ]

    result = [
        build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type)
        for arg, arg_type in arg_and_types
    ]
    result += [
        build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=arg_type)
        for arg, arg_type in arg_and_types_kwonlyargs
    ]
    return result


def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
    """Build a Param tree view for one argument.

    The type annotation is chosen in this order: an explicit annotation, then
    the profile-directed type, then ``self_name`` for an argument literally
    named ``self``, otherwise an empty annotation.
    """
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
    name = py_arg.arg
    r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
    if getattr(py_arg, "annotation", None) is not None:
        annotation_expr = build_expr(ctx, py_arg.annotation)
    elif pdt_arg_type:
        annotation_expr = Var(Ident(r, pdt_arg_type))
    elif self_name is not None and name == "self":
        annotation_expr = Var(Ident(r, self_name))
    else:
        annotation_expr = EmptyTypeAnnotation(r)
    return Param(annotation_expr, Ident(r, name), kwarg_only)


def build_ignore_context_manager(ctx, stmt):
    """Rewrite a ``torch.jit._IgnoreContextManager`` ``with`` block.

    The block body is moved into a generated ``@torch.jit.ignore`` function,
    and the statement is replaced by a call to it. The generated function is
    defined with ``exec`` (using ``astunparse``) and registered in this module's
    globals.

    Each keyword of the context-manager call is expected to be a string
    ``"inp:<type>"`` or ``"out:<type>"`` (assumed to be an ``ast.Constant`` --
    TODO confirm). Inputs become parameters; outputs become the returned values
    assigned back to the same names.

    Returns:
        The ``ast`` statement that calls the generated function.
    """
    InputType = namedtuple("InputType", ["name", "ann"])
    OutputType = namedtuple("OutputType", ["name", "ann"])

    def process_ins_outs(args):
        # parse the context manager to figure out inputs and outputs
        # with their annotated types
        # TODO: add input, output validator
        inputs = []
        outputs = []
        for arg in args:
            var_name = arg.arg
            var_ann = arg.value.value
            # NOTE(review): unpacking into two names raises ValueError if the
            # annotation string contains more than one ":".
            var_decl_type, var_ann = var_ann.split(":")
            if var_decl_type == "inp":
                inputs.append(InputType(var_name, var_ann))
            if var_decl_type == "out":
                outputs.append(OutputType(var_name, var_ann))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        # extension will be based on the full path filename plus
        # the line number of original context manager
        fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
        return f"{fn}_{stmt.lineno}"

    def build_return_ann_stmt(outputs):
        # Returns (" -> <type>" annotation suffix, "return <names>" source).
        return_type_ann = ""
        return_statement_str = "return "
        if len(outputs) == 0:
            return_type_ann += " -> None"
        if len(outputs) == 1:
            return_type_ann = " -> " + outputs[0].ann
            return_statement_str += outputs[0].name
        if len(outputs) > 1:
            return_type_ann = " -> Tuple"
            return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
            return_statement_str += ", ".join([var.name for var in outputs])
        return return_type_ann, return_statement_str

    def build_args(args):
        # Comma-separated variable names, used for both call args and assign targets.
        return ", ".join([arg.name for arg in args])

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # build the replacement function str with given inputs and outputs
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    ignore_function_str = "\ndef " + ignore_function_name
    ignore_function_str += (
        "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"
    )

    return_ann, return_stmt = build_return_ann_stmt(outputs)
    ignore_function_str += return_ann + ": pass"

    # first create the functionDef object from just declaration
    ignore_function = ast.parse(ignore_function_str).body[0]

    # dump the body of context manager to dummy function
    ignore_function.body = stmt.body  # type: ignore[attr-defined]

    # insert return statement to the function
    return_stmt = ast.parse(return_stmt).body[0]
    ignore_function.body.append(return_stmt)  # type: ignore[attr-defined]

    # registers the custom function in the global context
    # (exec runs source derived from the function being scripted)
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_str)  # noqa: P204

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    assign_str_lhs = build_args(outputs)
    # this function will be registered in torch.jit.frontend module by default
    assign_str_rhs = (
        f"torch.jit.frontend.{ignore_function_name}(" + build_args(inputs) + ")"
    )

    if len(outputs) > 0:
        assign_str = assign_str_lhs + " = " + assign_str_rhs
    else:
        assign_str = assign_str_rhs
    assign_ast = ast.parse(assign_str).body[0]
    return assign_ast


def get_default_args(fn):
    """
    Get a dictionary of default arguments for a function.

    Args:
        fn: Callable - The function to inspect for default arguments.
    Returns:
        (Dict[str, Any]): mapping argument names to their default values if
        :attr:`fn` is not None, else empty dictionary.
    """
    if fn is None:
        return {}

    signature = inspect.signature(fn)

    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


def get_default_args_for_class(cls):
    """
    Get default arguments for all methods in a class (except for static methods).

    Only methods defined directly in ``cls.__dict__`` are considered.

    Args:
        cls: type - The class type to inspect for default arguments.
    Returns:
        A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any]
        that maps each argument name to its default value.
    """
    # Get methods (except static methods because those are compiled separately as
    # if they were independent script functions).
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__,
    )

    # Get method defaults. Property defaults do not need to be considered
    # because setters cannot be invoked without a value.
    defaults = {
        method_name: get_default_args(method_impl)
        for method_name, method_impl in methods
    }

    return defaults


class WithItemBuilder(Builder):
    """Builder for ``ast.withitem`` nodes."""

    @staticmethod
    def build_withitem(ctx, item):
        """Build a WithItem from the context expression and optional ``as`` target."""
        lineno = item.context_expr.lineno
        start = item.context_expr.col_offset
        # NOTE(review): range width is len("with statements") (the pretty name),
        # not the keyword or the expression length -- looks accidental; confirm.
        end = start + len(pretty_node_names[ast.With])
        op_vars = item.optional_vars
        r = ctx.make_range(lineno, start, end)

        return WithItem(
            r,
            build_expr(ctx, item.context_expr),
            build_expr(ctx, op_vars) if op_vars else None,
        )


class StmtBuilder(Builder):
    """Builder that translates Python statement nodes into TorchScript Stmt views."""

    # Python augmented-assignment operator class -> TorchScript operator token.
    augassign_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Mod: "%",
        ast.BitOr: "|",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.LShift: "<<",
        ast.RShift: ">>",
        ast.Pow: "**",
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        """Build an expression statement; returns None for a string-literal docstring."""
        value = stmt.value
        # NOTE(review): on Python 3.8+ ast.parse produces ast.Constant (not
        # ast.Str) for string literals, so this branch may never match and
        # docstrings then go through build_expr -- confirm.
        if value.__class__.__name__ == "Str":
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        """Build an Assign with every target of ``a = b = value``."""
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_AnnAssign(ctx, stmt):
        """Build a typed Assign from ``target: type = value``.

        Raises:
            UnsupportedNodeError: if there is no assigned value.
            ValueError: for an annotated ``self.<attr>`` outside ``__init__``.
        """
        if stmt.value is None:
            raise UnsupportedNodeError(ctx, stmt, reason="without assigned value")

        # Disallow type annotations on instance attributes outside of __init__
        # NOTE(review): `.value.id` assumes the attribute base is an ast.Name;
        # for a target like `a.b.c: T = v` the base is an ast.Attribute and this
        # would raise AttributeError -- confirm.
        if (
            type(stmt.target) == ast.Attribute
            and stmt.target.value.id == "self"  # type: ignore[attr-defined]
            and ctx.funcname != "__init__"
        ):
            start = stmt.col_offset
            end = start + len(f"self.{stmt.target.attr}")
            if hasattr(stmt.annotation, "id"):
                end += len(f": {stmt.annotation.id}")
            sr = ctx.make_range(stmt.lineno, start, end)
            raise ValueError(
                "Type annotations on instance attributes must be declared in "
                f"__init__, not '{ctx.funcname}': {sr}"
            )

        rhs = build_expr(ctx, stmt.value)
        lhs = build_expr(ctx, stmt.target)
        the_type = build_expr(ctx, stmt.annotation)
        return Assign([lhs], rhs, the_type)

    @staticmethod
    def build_Delete(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del"))

        return Delete(r, [build_expr(ctx, target) for target in stmt.targets])

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")
        )
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

    @staticmethod
    def build_Raise(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
        # NOTE(review): a bare `raise` has stmt.exc == None, which is passed to
        # build_expr unchecked -- confirm how that is handled.
        expr = build_expr(ctx, stmt.exc)
        return Raise(r, expr)

    @staticmethod
    def build_Assert(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert")
        )
        test = build_expr(ctx, stmt.test)
        msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
        return Assert(r, test, msg)

    @staticmethod
    def build_AugAssign(ctx, stmt):
        """Build ``target <op>= value``; operators outside augassign_map are rejected."""
        lhs = build_expr(ctx, stmt.target)
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op in StmtBuilder.augassign_map:
            op_token = StmtBuilder.augassign_map[op]
        else:
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)),
                "unsupported kind of augmented assignment: " + op.__name__,
            )
        return AugAssign(lhs, op_token, rhs)

    @staticmethod
    def build_While(ctx, stmt):
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us useful
            # annotations in this case
            raise NotSupportedError(
                None, "else branches of while loops aren't supported"
            )
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body))

    @staticmethod
    def build_For(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
        if stmt.orelse:
            raise NotSupportedError(r, "else branches of for loops aren't supported")

        return For(
            r,
            [build_expr(ctx, stmt.target)],
            [build_expr(ctx, stmt.iter)],
            build_stmts(ctx, stmt.body),
        )

    @staticmethod
    def build_If(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(
            r,
            build_expr(ctx, stmt.test),
            build_stmts(ctx, stmt.body),
            build_stmts(ctx, stmt.orelse),
        )

    @staticmethod
    def build_Print(ctx, stmt):
        # NOTE(review): handles a `Print` statement node (reads .dest/.values);
        # presumably Python 2 only and unreachable under Python 3 -- confirm.
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
        if stmt.dest:
            raise NotSupportedError(
                r, "print statements with non-default destinations aren't supported"
            )
        args = [build_expr(ctx, val) for val in stmt.values]
        return ExprStmt(Apply(Var(Ident(r, "print")), args, []))

    @staticmethod
    def build_Pass(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
        return Pass(r)

    @staticmethod
    def build_Break(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break"))
        return Break(r)

    @staticmethod
    def build_Continue(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue")
        )
        return Continue(r)

    @staticmethod
    def build_With(ctx, stmt):
        """Build a With; torch.jit._IgnoreContextManager blocks are rewritten instead."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
        # Handle ignore context manager
        if is_torch_jit_ignore_context_manager(stmt):
            if not _IS_ASTUNPARSE_INSTALLED:
                # NOTE(review): the backslash continuation below embeds the next
                # line's leading spaces into the error message.
                raise RuntimeError(
                    "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \
                    please install it in your Python environment"
                )
            assign_ast = build_ignore_context_manager(ctx, stmt)
            return build_stmt(ctx, assign_ast)
        return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))


class ExprBuilder(Builder):
    binop_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Pow: "**",
        ast.Mod: "%",
        ast.FloorDiv: "//",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.BitOr: "|",
        ast.LShift: "<<",
        ast.RShift: ">>",
    }

    binop_map[ast.MatMult] = "@"

    unop_map = {
        ast.Not: "not",
        ast.USub: "-",
        ast.Invert: "~",
    }

    boolop_map = {
        ast.And: "and",
        ast.Or: "or",
    }

    cmpop_map = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.LtE: "<=",
        ast.Lt: "<",
        ast.GtE: ">=",
        ast.Gt: ">",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.In: "in",
        ast.NotIn: "not in",
    }

    @staticmethod
    def build_Attribute(ctx, expr):
        """Build a Select; the attribute name's range is found by scanning the source."""
        base = build_expr(ctx, expr.value)
        # expr.attr is just a string, so it's not annotated in any way, so we have
        # to build the range manually
        source = ctx.source.encode("utf-8")

        def get_char(index):
            return chr(source[index])

        start_pos = base.range().end + 1
        while get_char(start_pos) in string.whitespace:  # Skip whitespace
            start_pos += 1
        end_pos = start_pos + len(expr.attr)
        name_range = ctx.make_raw_range(start_pos, end_pos)
        return Select(base, Ident(name_range, expr.attr))

    @staticmethod
    def build_Call(ctx, expr):
        """Build an Apply; ``**kwargs`` expansion raises NotSupportedError."""
        func = build_expr(ctx, expr.func)
        args = [build_expr(ctx, py_arg) for py_arg in expr.args]
        if hasattr(expr, "starargs") and expr.starargs:
            stararg_expr = build_expr(ctx, expr.starargs)
            args += [Starred(stararg_expr.range(), stararg_expr)]
        kwargs = []
        for kw in expr.keywords:
            kw_expr = build_expr(ctx, kw.value)
            # XXX: we could do a better job at figuring out the range for the name here
            if not kw.arg:
                raise NotSupportedError(
                    kw_expr.range(), "keyword-arg expansion is not supported"
                )
            kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
        return Apply(func, args, kwargs)

    @staticmethod
    def build_Ellipsis(ctx, expr):
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + 3
        )  # len("...") == 3
        return Dots(r)

    @staticmethod
    def build_Name(ctx, expr):
        """Build a Var, or a literal for True/False/None/Ellipsis; rejects reserved-prefix names."""
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(
                r,
                "names of variables used in JIT-ed functions "
                "can't start with " + _reserved_prefix,
            )
        if expr.id == "True":
            return TrueLiteral(r)
        elif expr.id == "False":
            return FalseLiteral(r)
        elif expr.id == "None":
            return NoneLiteral(r)
        elif expr.id == "Ellipsis":
            return Dots(r)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_NameConstant(ctx, expr):
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value))
        )
        if expr.value is True:
            return TrueLiteral(r)
        elif expr.value is False:
            return FalseLiteral(r)
        elif expr.value is None:
            return NoneLiteral(r)
        elif expr.value == Ellipsis:
            return Dots(r)
        else:
            raise ValueError("Name constant value unsupported: " + str(expr.value))

    @staticmethod
    def build_BinOp(ctx, expr):
        """Build a BinOp; rejects ``/`` when the file lacks true-division semantics."""
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)

        if op == ast.Div and not ctx.uses_true_division:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise FrontendError(
                err_range,
                "Division of ints in TorchScript uses Python 3 true "
                "division semantics. Please put `from __future__ "
                "import division` at the top of your file",
            )
        op_token = ExprBuilder.binop_map.get(op)
        if op_token is None:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(
                err_range, "unsupported binary operator: " + op.__name__
            )
        return BinOp(op_token, lhs, rhs)

    @staticmethod
    def build_UnaryOp(ctx, expr):
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            # NOTE(review): `expr` is a Python ast node here, which has no
            # `.range()` method, so this path likely raises AttributeError
            # (e.g. for unary `+`, since ast.UAdd is not in unop_map) -- confirm.
            raise NotSupportedError(
                expr.range(), "unsupported unary operator: " + op.__name__
            )
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(op_token)
        )
        return UnaryOp(r, op_token, sub_expr)

    @staticmethod
    def build_BoolOp(ctx, expr):
        """Fold ``a and b and c`` into left-nested BinOp("and", ...) views."""
        if len(expr.values) < 2:
            raise AssertionError(
                "expected at least 2 values in BoolOp, but got " + str(len(expr.values))
            )
        sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
        op = type(expr.op)
        op_token = ExprBuilder.boolop_map.get(op)
        if op_token is None:
            err_range = ctx.make_raw_range(
                sub_exprs[0].range().end, sub_exprs[1].range().start
            )
            raise NotSupportedError(
                err_range, "unsupported boolean operator: " + op.__name__
            )
        lhs = sub_exprs[0]
        for rhs in sub_exprs[1:]:
            lhs = BinOp(op_token, lhs, rhs)
        return lhs

    @staticmethod
    def build_IfExp(ctx, expr):
        return TernaryIf(
            build_expr(ctx, expr.test),
            build_expr(ctx, expr.body),
            build_expr(ctx, expr.orelse),
        )

    @staticmethod
    def build_Compare(ctx, expr):
        """Build chained comparisons ``a < b < c`` as ``(a < b) and (b < c)``."""
        operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
        result = None
        for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
            op = type(op_)
            op_token = ExprBuilder.cmpop_map.get(op)
            r = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            if op_token is None:
                raise NotSupportedError(
                    r, "unsupported comparison operator: " + op.__name__
                )

            if op == ast.NotIn:
                # NB: `not in` is just `not( in )`, so we don't introduce new tree view
                # but just make it a nested call in our tree view structure
                in_expr = BinOp("in", lhs, rhs)
                cmp_expr = UnaryOp(r, "not", in_expr)
            else:
                cmp_expr = BinOp(op_token, lhs, rhs)

            if result is None:
                result = cmp_expr
            else:
                result = BinOp("and", result, cmp_expr)
        return result

    @staticmethod
    def build_Subscript(ctx, expr):
        def build_SliceExpr(ctx, base, slice_expr):
            lower = (
                build_expr(ctx, slice_expr.lower)
                if slice_expr.lower is not None
                else None
            )
            upper = (
                build_expr(ctx, slice_expr.upper)
                if slice_expr.upper is not None
                else None
            )
            step = (
                build_expr(ctx, slice_expr.step)
                if slice_expr.step is not None
                else None
            )
            return SliceExpr(base.range(), lower, upper, step)

        def build_Index(ctx, base, index_expr):
            if isinstance(index_expr.value, ast.Tuple):
                raise NotSupportedError(
                    base.range(),
                    "slicing multiple dimensions with tuples not supported yet",
                )
            return build_expr(ctx, index_expr.value)

        def build_ExtSlice(ctx, base, extslice):
            sub_exprs = []
            for expr in extslice.dims:
                sub_type = type(expr)
                if sub_type is ast.Index:
                    sub_exprs.append(build_Index(ctx, base, expr))
                elif sub_type is ast.Slice:
                    sub_exprs.append(build_SliceExpr(ctx, base, expr))
                elif sub_type is ast.Constant and expr.value is Ellipsis:
                    sub_exprs.append(Dots(base.range()))
                else:
                    raise NotSupportedError(
                        base.range(),
                        f"slicing multiple dimensions with {sub_type} not supported",
                    )
            return sub_exprs

        base = build_expr(ctx, expr.value)
        sub_type = type(expr.slice)
        if sub_type is ast.Index:
            if isinstance(expr.slice.value, ast.Tuple):
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                # XXX: Indexing using a list is **different**! It triggers advanced indexing.
                indices = [
                    build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts
                ]
                if not indices:
                    # `col_offset` is an int, but `end_col_offset` is
                    # `Optional[int]`. The magic number is here to make
                    # sure we can parse `()` on any machine
                    r = ctx.make_range(
                        expr.lineno,
                        expr.slice.value.col_offset,
                        expr.slice.value.col_offset + 2,
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            else:
                return Subscript(base, [build_expr(ctx, expr.slice.value)])
        elif sub_type is ast.Slice:
            return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
        elif sub_type is ast.ExtSlice:
            return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
        elif sys.version_info >= (
            3,
            9,
        ):  # In Python3.9 array indicies are not wrapped in ast.Index
            if sub_type is ast.Tuple:
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                indices = []
                for index_expr in expr.slice.elts:
                    if isinstance(index_expr, ast.Slice):
                        indices.append(build_SliceExpr(ctx, base, index_expr))
                    else:
                        indices.append(build_expr(ctx, index_expr))
                # Special-case logic for `typing.Tuple[()]`
                if not indices:
                    # See note above r.e.
magic number 1146 r = ctx.make_range( 1147 expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2 1148 ) 1149 tup = TupleLiteral(r, []) 1150 indices.append(tup) 1151 return Subscript(base, indices) 1152 return Subscript(base, [build_expr(ctx, expr.slice)]) 1153 else: # Ellipsis (can only happen in Python 2) 1154 raise NotSupportedError(base.range(), "ellipsis is not supported") 1155 1156 @staticmethod 1157 def build_List(ctx, expr): 1158 return ListLiteral( 1159 ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), 1160 [build_expr(ctx, e) for e in expr.elts], 1161 ) 1162 1163 @staticmethod 1164 def build_Tuple(ctx, expr): 1165 return TupleLiteral( 1166 ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1), 1167 [build_expr(ctx, e) for e in expr.elts], 1168 ) 1169 1170 @staticmethod 1171 def build_Dict(ctx, expr): 1172 range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) 1173 if expr.keys and not expr.keys[0]: 1174 raise NotSupportedError( 1175 range, "Dict expansion (e.g. 
`{**dict}`) is not supported" 1176 ) 1177 return DictLiteral( 1178 range, 1179 [build_expr(ctx, e) for e in expr.keys], 1180 [build_expr(ctx, e) for e in expr.values], 1181 ) 1182 1183 @staticmethod 1184 def build_Num(ctx, expr): 1185 value = str(expr.value) 1186 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value)) 1187 return Const(r, value) 1188 1189 @staticmethod 1190 def build_Constant(ctx, expr): 1191 value = expr.value 1192 if value is None or isinstance(value, bool): 1193 # NB: this check has to happen before the int check because bool is 1194 # a subclass of int 1195 return ExprBuilder.build_NameConstant(ctx, expr) 1196 if isinstance(value, (int, float, complex)): 1197 return ExprBuilder.build_Num(ctx, expr) 1198 elif isinstance(value, str): 1199 return ExprBuilder.build_Str(ctx, expr) 1200 elif isinstance(value, type(Ellipsis)): 1201 return ExprBuilder.build_Ellipsis(ctx, expr) 1202 else: 1203 error_range = ctx.make_range( 1204 expr.lineno, expr.col_offset, expr.col_offset + len(str(value)) 1205 ) 1206 raise FrontendError(error_range, "Unknown Constant expression type") 1207 1208 @staticmethod 1209 def build_Str(ctx, expr): 1210 value = str(expr.value) 1211 r = ctx.make_range( 1212 expr.lineno, expr.col_offset, expr.col_offset + len(value) + 1 1213 ) 1214 return StringLiteral(r, value) 1215 1216 @staticmethod 1217 def build_JoinedStr(ctx, expr): 1218 s = "" 1219 args = [] 1220 for value in expr.values: 1221 r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1) 1222 if isinstance(value, ast.FormattedValue): 1223 if value.conversion != -1: 1224 raise NotSupportedError(r, "Don't support conversion in JoinedStr") 1225 if value.format_spec is not None: 1226 raise NotSupportedError(r, "Don't support formatting in JoinedStr") 1227 s += "{}" 1228 args.append(build_expr(ctx, value.value)) 1229 elif isinstance(value, ast.Constant): 1230 s += value.value 1231 else: 1232 raise NotSupportedError(r, "Unsupported value in 
JoinedStr") 1233 1234 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) 1235 return Apply(Select(StringLiteral(r, s), Ident(r, "format")), args, []) 1236 1237 @staticmethod 1238 def build_ListComp(ctx, stmt): 1239 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) 1240 if len(stmt.generators) != 1: 1241 raise NotSupportedError(r, "Only a single generator is currently supported") 1242 1243 if len(stmt.generators[0].ifs) != 0: 1244 raise NotSupportedError(r, "Comprehension ifs are not supported yet") 1245 1246 elt_expr = build_expr(ctx, stmt.elt) 1247 target_expr = build_expr(ctx, stmt.generators[0].target) 1248 iter_expr = build_expr(ctx, stmt.generators[0].iter) 1249 1250 return ListComp(r, elt_expr, target_expr, iter_expr) 1251 1252 @staticmethod 1253 def build_GeneratorExp(ctx, stmt): 1254 # Convert Generator expression to ListComp 1255 return ExprBuilder.build_ListComp(ctx, stmt) 1256 1257 @staticmethod 1258 def build_DictComp(ctx, stmt): 1259 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset) 1260 if len(stmt.generators) != 1: 1261 raise NotSupportedError(r, "Only a single generator is currently supported") 1262 1263 if len(stmt.generators[0].ifs) != 0: 1264 raise NotSupportedError(r, "Comprehension ifs are not supported yet") 1265 1266 key_expr = build_expr(ctx, stmt.key) 1267 value_expr = build_expr(ctx, stmt.value) 1268 target_expr = build_expr(ctx, stmt.generators[0].target) 1269 iter_expr = build_expr(ctx, stmt.generators[0].iter) 1270 1271 return DictComp(r, key_expr, value_expr, target_expr, iter_expr) 1272 1273 @staticmethod 1274 def build_Starred(ctx, expr): 1275 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1) 1276 return Starred(r, build_expr(ctx, expr.value)) 1277 1278 1279build_expr = ExprBuilder() 1280build_stmt = StmtBuilder() 1281build_withitem = WithItemBuilder() 1282 1283 1284def find_before(ctx, pos, substr, offsets=(0, 0)): 1285 new_pos = ctx.source[:pos].rindex(substr) 
1286 return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1]) 1287