# mypy: ignore-errors

import collections
import functools
import inspect
import operator
import types
from typing import Dict, List, Optional, TYPE_CHECKING

import torch
import torch.fx
from torch._guards import Source

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import raise_observed_exception, unimplemented
from ..source import AttrSource
from ..utils import (
    get_fake_value,
    guard_if_dyn,
    is_namedtuple,
    istype,
    iter_contains,
    Lit,
    namedtuple_fields,
    odict_values,
    set_example_value,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class BaseListVariable(VariableTracker):
    """Base tracker for sequence-like values stored as a list of VariableTrackers.

    Subclasses supply ``python_type()`` and ``reconstruct()``; this class
    provides indexing, containment, ``index`` and constant/proxy conversion
    on top of ``self.items``.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the constructor to use for tracking the instance ``obj``.

        Namedtuples get a ``NamedTupleVariable`` partially applied with their
        concrete class; everything else is dispatched on ``type(obj)``.
        """
        if is_namedtuple(obj):
            return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Map a Python type (or ``iter``) to its tracker class.

        Raises ``KeyError`` when ``obj`` is not one of the mapped keys.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return a plain list holding the proxy of every item."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new tracker of the same class holding ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Join the items' debug reprs with ``", "`` between ``prefix`` and ``suffix``."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # python_type() returns a Python type, never a tracker class, so the
        # check has to be against torch.Size itself. torch.Size cannot hold
        # proxies; SizeVariable overrides as_proxy() to handle that case.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice ``self.items`` with ``arg``.

        ``arg`` is either a ``SymNodeVariable`` (its ``sym_num`` is used) or
        something convertible with ``as_python_constant()``. A slice yields a
        clone holding the sliced items; an int/SymInt yields the stored item.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutable_local=MutableLocal() if self.mutable_local else None,
            )
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def unpack_var_sequence(self, tx):
        # Return a copy so callers cannot mutate self.items through the result.
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__``, ``__contains__`` and ``index``; defer the rest."""
        if name == "__getitem__":
            from .tensor import TensorVariable

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # A tensor index is only usable when its fake value carries a
                # single-element constant that can be turned into a scalar.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented("__getitem__ with non-constant tensor")
            else:
                value = args[0]
            return self.getitem_const(tx, value)
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            from .builder import SourcelessBuilder

            # Delegate to the polyfill by inlining it with self as first arg.
            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx: "InstructionTranslator", op, left, right):
        """Compare ``left`` and ``right`` with ``op`` via ``polyfills.list_cmp``."""
        return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
            tx, [variables.BuiltinVariable(op), left, right], {}
        )


class RangeVariable(BaseListVariable):
    """Tracks a ``range``; ``self.items`` is always ``[start, stop, step]``."""

    def __init__(self, items, **kwargs) -> None:
        # Accept 1, 2 or 3 items and fill in start=0 / step=1 when omitted.
        items_to_map = items
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        return self.items[0].as_python_constant()

    def stop(self):
        return self.items[1].as_python_constant()

    def step(self):
        return self.items[2].as_python_constant()

    def range_length(self):
        """Return the number of elements in the range."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Normalize ``slice`` against a sequence of ``length`` elements.

        Returns ``[start, stop, step]`` with missing bounds defaulted and
        explicit bounds wrapped (if negative) and clamped to ``[lower, upper]``.
        """
        if slice.step is None:
            step = 1
            step_is_negative = False
        else:
            step = slice.step
            step_is_negative = slice.step < 0

        # Find lower and upper bounds for start and stop.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        def clamp(bound):
            # Wrap a negative bound around the end, then clip into range.
            if bound < 0:
                bound += length
                if bound < lower:
                    bound = lower
            else:
                if bound > upper:
                    bound = upper
            return bound

        # Compute start. Defaults are already in range and must not be
        # re-normalized (a defaulted -1 would otherwise wrap around).
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = clamp(slice.start)

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper
        else:
            stop = clamp(slice.stop)

        return [start, stop, step]

    def apply_index(self, index):
        """Return the element at ``index`` as a constant.

        Raises ``IndexError`` when the (wrapped) index is outside the range.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            raise IndexError(f"index {index} is out of range")

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new ``RangeVariable`` describing ``range[slice]``."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        result = RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutable_local=MutableLocal() if self.mutable_local else None,
        )
        return result

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            return self.apply_slice(index)
        else:
            return self.apply_index(index)

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen):
        """Emit bytecode that calls ``range`` with the three stored items."""
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)]


class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle the mutating list methods plus ``copy``.

        Mutating branches are only taken when ``self.mutable_local`` is set,
        and each records the mutation with ``tx.output.side_effects`` before
        touching ``self.items``. Anything unhandled goes to the base class.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.mutable_local:
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif (
            name == "extend"
            and self.mutable_local
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            seq = arg.force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items.extend(seq)
            return ConstantVariable.create(None)
        elif name == "insert" and self.mutable_local:
            assert not kwargs
            idx, value = args
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.mutable_local:
            assert not kwargs and not args
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                self.items[key.as_python_constant()] = list(value.items)
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            assert not kwargs
            assert not args
            items = list(self.items)
            return self.modified(items, mutable_local=MutableLocal())
        elif name == "reverse" and self.mutable_local:
            assert not kwargs
            assert not args
            # Record the mutation first, like every other mutator here, so
            # nothing is changed if recording the side effect fails.
            tx.output.side_effects.mutation(self)
            self.items.reverse()
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)


class ListVariable(CommonListMethodsVariable):
    """Tracks a Python ``list``."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__setitem__`` with slice expansion; defer everything else."""
        if (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            tx.output.side_effects.mutation(self)
            if isinstance(key, SliceVariable):
                # Slice assignment needs the right-hand side as a sequence.
                if not value.has_force_unpack_var_sequence(tx):
                    unimplemented(
                        f"Missing dynamo support for expanding {value} into a list for slice assignment."
                    )
                self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
                    tx
                )
            else:
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is list:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        # Only a plain list can be answered from an empty list; subclasses
        # (python_type() is not list) go through the generic path.
        if self.python_type() is not list:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr([], name))


class DequeVariable(CommonListMethodsVariable):
    """Tracks a ``collections.deque``, stored as a plain list of items."""

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        return self.debug_repr_helper("deque([", "])")

    def reconstruct(self, codegen):
        """Emit bytecode that builds a list of the items and calls ``deque`` on it."""
        assert "deque" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_LIST", arg=len(self.items)),
                *create_call_function(1, False),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle int-keyed ``__setitem__`` and the left-end deque methods."""
        if (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            assert key.is_python_constant() and isinstance(
                key.as_python_constant(), int
            )
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif (
            name == "extendleft"
            and self.mutable_local
            # Guard against a zero-argument call before indexing args[0];
            # such a call falls through to the base class instead.
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs

            (arg,) = args
            prefix = arg.force_unpack_var_sequence(tx)
            # extendleft inserts one element at a time, reversing their order.
            prefix.reverse()
            tx.output.side_effects.mutation(self)
            self.items = prefix + list(self.items)
            return ConstantVariable.create(None)
        elif name == "popleft" and self.mutable_local:
            assert not args
            assert not kwargs
            item = self.items[0]
            tx.output.side_effects.mutation(self)
            self.items = self.items[1:]
            return item
        elif name == "appendleft" and self.mutable_local:
            assert not kwargs
            tx.output.side_effects.mutation(self)
            self.items = [args[0]] + list(self.items)
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)


class TupleVariable(BaseListVariable):
    """Tracks a Python ``tuple``."""

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items)))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__class__":
            source = AttrSource(self.source, name) if self.source else None
            class_type = self.python_type()
            if class_type is tuple:
                return variables.BuiltinVariable(class_type, source=source)
            else:
                return variables.UserDefinedClassVariable(class_type, source=source)
        return super().var_getattr(tx, name)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        # Only a plain tuple can be answered from an empty tuple; subclasses
        # (python_type() is not tuple) go through the generic path.
        if self.python_type() is not tuple:
            return super().call_hasattr(tx, name)
        return variables.ConstantVariable.create(hasattr((), name))


class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    _nonvar_fields = {
        "proxy",
        *TupleVariable._nonvar_fields,
    }

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        """Return the stored proxy, a constant ``torch.Size``, or a new proxy node."""
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling.  Normally, we pun a list-like
        # container to directly contain Proxy/Node objects from FX, and FX
        # knows to look inside containers (via map_aggregate).  But torch.Size
        # is weird; although it subclasses from tuple, it doesn't allow
        # members which aren't int-like (rejecting Proxy and Node).  This
        # means we can't use the normal representation trick
        # torch.Size([proxy0, proxy1]).  I looked into seeing if I could
        # relax torch.Size in PyTorch proper, but if torch.Size constructor
        # sees a type that it doesn't recognize, it will try to call
        # __index__() on it, so there is no BC way to actually change this
        # behavior (though it occurs to me that I could have just added a
        # YOLO no checking alternate constructor.)
        #
        # To work around this problem, I represent a torch.Size proxy as
        # a straight up proxy, that would have been constructed by taking
        # the constituent proxies as arguments.  This trick can be generally
        # used for any construct that we need a proxy for but we can't
        # directly represent as an aggregate; I don't see very many examples
        # of this in torchdynamo though!

        # Look for a proxy.  If there are none, do the legacy behavior
        tracer = None
        proxies = self._as_proxy()
        for proxy in proxies:
            if isinstance(proxy, torch.fx.Proxy):
                tracer = proxy.tracer
                break

        if tracer is None:
            return torch.Size(proxies)

        proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        set_example_value(
            proxy.node,
            torch.Size(
                [
                    p.node.meta["example_value"] if not isinstance(p, int) else p
                    for p in proxies
                ]
            ),
        )
        return proxy

    def reconstruct(self, codegen):
        """Emit bytecode that builds a tuple of the items and calls ``torch.Size``."""
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        build_torch_size = [
            create_instruction("BUILD_TUPLE", arg=len(self.items)),
        ] + create_call_function(1, False)
        codegen.extend_output(build_torch_size)

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def numel(self, tx):
        """Return the product of all items.

        Constant items are multiplied eagerly; ``SymNodeVariable`` items are
        multiplied through ``operator.mul`` only when that is needed.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_result = 1
        sym_sizes = []

        for v in self.items:
            if isinstance(v, ConstantVariable):
                const_result *= v.value
            else:
                assert isinstance(v, SymNodeVariable), type(v)
                # Delay proxy calls until we know it will be necessary
                sym_sizes.append(v)

        result = ConstantVariable.create(const_result)
        if sym_sizes and const_result == 1:
            # Skip multiplying by 1
            result, *sym_sizes = sym_sizes

        if not sym_sizes or const_result == 0:
            return result

        mul = BuiltinVariable(operator.mul)
        for v in sym_sizes:
            result = mul.call_function(tx, [result, v], {})
        return result

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            out = self.get_item_dyn(tx, args[0])
            return out
        elif name == "numel":
            assert not args and not kwargs
            return self.numel(tx)

        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index or slice the size; a slice yields a new ``SizeVariable``."""
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()
        if isinstance(index, slice):
            return SizeVariable(self.items[index])
        else:
            assert isinstance(index, (int, torch.SymInt))
            return self.items[index]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(torch.Size, name))


class NamedTupleVariable(TupleVariable):
    """Tracks an instance of the namedtuple class ``tuple_cls``."""

    _nonvar_fields = {
        "tuple_cls",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, **kwargs) -> None:
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls

    def debug_repr(self):
        return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        return self.python_type()(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # python_type() returns a Python type, never a tracker class, so the
        # check has to be against torch.Size itself.
        assert self.python_type() is not torch.Size
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen):
        """Emit bytecode calling ``tuple_cls._make`` (or the class) on a tuple of items."""
        create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
        codegen.add_push_null(
            lambda: codegen.append_output(codegen._create_load_const(create_fn))
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
            ]
            + create_call_function(1, False)
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve a field by name, else a method defined on ``tuple_cls``."""

        def check_and_create_method():
            method = inspect.getattr_static(self.tuple_cls, name, None)
            if isinstance(method, classmethod):
                # We need the unbounded cls method to avoid the inline __self__
                return UserMethodVariable(
                    method.__func__,
                    variables.UserDefinedClassVariable(self.tuple_cls),
                )
            elif isinstance(method, staticmethod):
                return UserFunctionVariable(method.__func__)
            elif inspect.isfunction(method):
                return UserMethodVariable(method, self)
            else:
                return None

        fields = namedtuple_fields(self.tuple_cls)
        if name not in fields:
            method = check_and_create_method()
            if not method:
                return super().var_getattr(tx, name)
            return method
        return self.items[fields.index(name)]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))


class SliceVariable(BaseListVariable):
    """Tracks a ``slice``; ``self.items`` is always ``[start, stop, step]``."""

    def __init__(self, items, **kwargs) -> None:
        # Accept 1, 2 or 3 items; omitted fields default to a None constant.
        items_to_map = items
        start, stop, step = [variables.ConstantVariable.create(None)] * 3

        if len(items_to_map) == 1:
            (stop,) = items_to_map
        elif len(items_to_map) == 2:
            start, stop = items_to_map
        elif len(items_to_map) == 3:
            start, stop, step = items_to_map
        else:
            raise AssertionError

        if isinstance(start, variables.TensorVariable) or isinstance(
            stop, variables.TensorVariable
        ):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        return slice(*[guard_if_dyn(x) for x in self.items])

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"slice.{name}")
        return self.items[fields.index(name)]


class ListIteratorVariable(IteratorVariable):
    """Tracks an iterator over a list of items; ``index`` is the next position."""

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492

        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"

    def next_variable(self, tx):
        """Return the next item and advance, or raise an observed ``StopIteration``."""
        assert self.mutable_local
        old_index = self.index
        if old_index >= len(self.items):
            raise_observed_exception(StopIteration, tx, self)

        tx.output.side_effects.mutation(self)
        self.index += 1
        return self.items[old_index]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        if name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            # Only the not-yet-consumed items are searched.
            return iter_contains(self.items[self.index :], args[0], tx)

        return super().call_method(tx, name, args, kwargs)

    def python_type(self):
        return type(iter([]))

    def as_python_constant(self):
        # A partially consumed iterator is not converted to a constant.
        if self.index > 0:
            raise NotImplementedError
        return iter([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        return list(self.items[self.index :])

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen):
        """Emit bytecode building a tuple of the remaining items and iterating it."""
        remaining_items = self.items[self.index :]
        codegen.foreach(remaining_items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(remaining_items)),
                create_instruction("GET_ITER"),
            ]
        )


class TupleIteratorVariable(ListIteratorVariable):
    pass


class RestrictedListSubclassVariable(ListVariable):
    """
    This is a special case of UserDefinedObjectVariable where:
        1) The user subclasses list
        2) None of the list methods are overridden, merely some new methods are added

    In these cases, we can prevent graph breaks by not using the general
    UserDefinedObjectVariable machinery and instead treating it like
    a ListVariable.
    """

    _nonvar_fields = {"user_cls", "user_cls_source", *ListVariable._nonvar_fields}
    _allowed_names = {
        "__call__",
        "__module__",
        "__dict__",
        "__doc__",
        "__name__",
        "__qualname__",
    }
    _disallowed_names = {
        "__getattribute__",
        "__getattr__",
        "__setattr__",
    }

    @classmethod
    def _is_non_conflicting_subclass(
        cls,
        user_cls: type,
        python_cls: type,
    ):
        """Ensures user_cls inherits from python_cls (e.g. list) and does not override any methods on python_cls"""
        if (
            not istype(user_cls, type)
            or user_cls.__bases__ != (python_cls,)
            or user_cls.__mro__ != (user_cls, python_cls, object)
        ):
            return False  # not subclass
        return not any(
            hasattr(python_cls, name) or name in cls._disallowed_names
            for name in set(user_cls.__dict__.keys()) - cls._allowed_names
        )

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        return cls._is_non_conflicting_subclass(user_cls, list)

    def __init__(
        self, items, *, user_cls: type, user_cls_source: Source, **kwargs
    ) -> None:
        super().__init__(items=items, **kwargs)
        self.user_cls = user_cls
        self.user_cls_source = user_cls_source
        assert istype(user_cls, type)
        assert isinstance(user_cls_source, Source)

    def debug_repr(self):
        # The constructor is safe as no methods, including __init__, are
        # allowed to be overridden
        # NB: This is guaranteed to print like a list, as __repr__ cannot be
        # overridden, this is... well, it's OK I guess (consistent with
        # eager), but it could be misleading.  You will have to query type
        # instead for details.
        return repr(self.user_cls([Lit(x.debug_repr()) for x in self.items]))

    def python_type(self):
        return self.user_cls

    def as_proxy(self):
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        raise NotImplementedError

    def is_python_constant(self):
        return False

    @property
    def value(self):
        raise AttributeError("value")

    def modified(self, items, **kwargs):
        return type(self)(
            items,
            user_cls=self.user_cls,
            user_cls_source=self.user_cls_source,
            **kwargs,
        )

    def reconstruct(self, codegen):
        """Emit bytecode that builds the list and passes it to the user class."""
        codegen.add_push_null(lambda: codegen(self.user_cls_source))
        super().reconstruct(codegen)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Inline plain functions defined on ``user_cls``; defer list methods."""
        if name in self.user_cls.__dict__:
            method = self.user_cls.__dict__[name]
            if isinstance(method, types.FunctionType):
                # inline the method
                source = AttrSource(self.user_cls_source, name)
                return UserMethodVariable(method, self, source=source).call_function(
                    tx, args, kwargs
                )
            unimplemented(
                f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}"
            )
        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.call_method(tx, "__call__", args, kwargs)