xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/lists.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4import functools
5import inspect
6import operator
7import types
8from typing import Dict, List, Optional, TYPE_CHECKING
9
10import torch
11import torch.fx
12from torch._guards import Source
13
14from .. import polyfills, variables
15from ..bytecode_transformation import create_call_function, create_instruction
16from ..exc import raise_observed_exception, unimplemented
17from ..source import AttrSource
18from ..utils import (
19    get_fake_value,
20    guard_if_dyn,
21    is_namedtuple,
22    istype,
23    iter_contains,
24    Lit,
25    namedtuple_fields,
26    odict_values,
27    set_example_value,
28)
29from .base import MutableLocal, VariableTracker
30from .constant import ConstantVariable
31from .functions import UserFunctionVariable, UserMethodVariable
32from .iter import IteratorVariable
33
34
35if TYPE_CHECKING:
36    from torch._dynamo.symbolic_convert import InstructionTranslator
37
38
class BaseListVariable(VariableTracker):
    """Shared base for VariableTrackers that model list-like containers.

    Elements are stored as a Python ``list`` of VariableTrackers in
    ``self.items``. Subclasses supply ``python_type()`` (the container type
    being modelled) plus any type-specific methods, codegen and proxies.
    """

    @staticmethod
    def cls_for_instance(obj):
        """Return the VariableTracker class (or factory) that should wrap ``obj``.

        Namedtuples get a ``NamedTupleVariable`` factory bound to their concrete
        class. Everything else is dispatched on ``type(obj)`` via ``cls_for``.
        """
        if is_namedtuple(obj):
            return functools.partial(NamedTupleVariable, tuple_cls=type(obj))
        return BaseListVariable.cls_for(type(obj))

    @staticmethod
    def cls_for(obj):
        """Map a supported container type to the VariableTracker class modelling it.

        The table is built at call time because several of the classes it
        names are defined later in this module.

        Raises:
            KeyError: if ``obj`` is not one of the supported types.
        """
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        """Return ``[x.as_proxy() for x in items]`` as a plain list."""
        return [x.as_proxy() for x in self.items]

    def modified(self, items, **kwargs):
        """Build a new variable of the same class holding ``items``."""
        return type(self)(items, **kwargs)

    @property
    def value(self):
        return self.as_python_constant()

    def debug_repr_helper(self, prefix, suffix):
        """Render ``items`` comma-separated between ``prefix`` and ``suffix``."""
        return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

    def as_python_constant(self):
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # torch.Size rejects members that are not int-like (e.g. fx Proxies),
        # so it cannot use this generic "container of proxies" form.
        # SizeVariable overrides as_proxy for that reason.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index ``items`` with a constant/symbolic int or a constant slice.

        A slice produces a new variable of the same class with ``source=None``,
        because slicing creates a new local object. An integer index returns
        the stored element itself.
        """
        from .tensor import SymNodeVariable

        if isinstance(arg, SymNodeVariable):
            index = arg.sym_num
        else:
            index = arg.as_python_constant()

        if isinstance(index, slice):
            # Set source to None because slicing a list gives a new local
            return self.clone(
                items=self.items[index],
                source=None,
                mutable_local=MutableLocal() if self.mutable_local else None,
            )
        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__``, ``__contains__`` and ``index``; defer the rest."""
        if name == "__getitem__":
            from .tensor import TensorVariable

            assert not kwargs and len(args) == 1
            if isinstance(args[0], TensorVariable):
                # Only a tensor whose fake value carries a known single-element
                # constant can be turned into a Python index.
                value = get_fake_value(args[0].as_proxy().node, tx)
                if value.constant is not None and value.constant.numel() == 1:
                    value = variables.ConstantVariable.create(value.constant.item())
                else:
                    unimplemented("__getitem__ with non-constant tensor")
            else:
                value = args[0]
            return self.getitem_const(tx, value)
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
        elif name == "index":
            from .builder import SourcelessBuilder

            # Trace the pure-Python polyfill rather than modelling index() here.
            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.index),
                [self] + list(args),
                kwargs,
            )

        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx: "InstructionTranslator", op, left, right):
        """Compare two list-like variables with ``op`` via ``polyfills.list_cmp``."""
        return variables.UserFunctionVariable(polyfills.list_cmp).call_function(
            tx, [variables.BuiltinVariable(op), left, right], {}
        )
152
153
class RangeVariable(BaseListVariable):
    """Models a ``range`` object.

    ``items`` is normalised to ``[start, stop, step]`` from the one-, two- or
    three-argument constructor forms. The ``start``/``stop``/``step``
    accessors below require each component to be a Python constant.
    """

    def __init__(self, items, **kwargs) -> None:
        start = variables.ConstantVariable.create(0)
        stop = None
        step = variables.ConstantVariable.create(1)

        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            raise AssertionError

        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("range(", ")")

    def python_type(self):
        return range

    def start(self):
        return self.items[0].as_python_constant()

    def stop(self):
        return self.items[1].as_python_constant()

    def step(self):
        return self.items[2].as_python_constant()

    def range_length(self):
        """Return the number of values the range yields."""
        lo = self.start()
        hi = self.stop()
        step = self.step()

        assert step != 0
        if step > 0 and lo < hi:
            return 1 + (hi - 1 - lo) // step
        elif step < 0 and lo > hi:
            return 1 + (lo - 1 - hi) // (0 - step)
        else:
            return 0

    def _get_slice_indices(self, length, slice):
        """Resolve ``slice`` against a sequence of ``length`` elements.

        ``None`` fields take their defaults, negative indices count from the
        end, and out-of-range bounds are clamped. Returns ``[start, stop, step]``.

        Raises:
            ValueError: if the slice step is zero, matching CPython's
                ``range(...)[::0]``.
        """
        if slice.step is None:
            step = 1
        else:
            step = slice.step
            if step == 0:
                raise ValueError("slice step cannot be zero")
        step_is_negative = step < 0

        # Bounds that start/stop are clamped into.
        if step_is_negative:
            lower = -1
            upper = length + lower
        else:
            lower = 0
            upper = length

        # Compute start.
        if slice.start is None:
            start = upper if step_is_negative else lower
        else:
            start = slice.start

        if start < 0:
            start += length
            if start < lower:
                start = lower
        elif start > upper:
            start = upper

        # Compute stop.
        if slice.stop is None:
            stop = lower if step_is_negative else upper
        else:
            stop = slice.stop
            if stop < 0:
                stop += length
                if stop < lower:
                    stop = lower
            elif stop > upper:
                stop = upper

        return [start, stop, step]

    def apply_index(self, index):
        """Return the range element at integer ``index`` (negative allowed).

        Raises:
            IndexError: if ``index`` is out of range.
        """
        length = self.range_length()
        if index < 0:
            index = length + index

        if index < 0 or index >= length:
            raise IndexError(f"index {index} is out of range")

        return variables.ConstantVariable.create(self.start() + (index * self.step()))

    def apply_slice(self, slice):
        """Return a new RangeVariable equal to ``range(...)[slice]``."""
        (slice_start, slice_stop, slice_step) = self._get_slice_indices(
            self.range_length(), slice
        )

        def compute_item(index):
            return self.start() + (index * self.step())

        sub_step = self.step() * slice_step
        sub_start = compute_item(slice_start)
        sub_stop = compute_item(slice_stop)

        return RangeVariable(
            [
                variables.ConstantVariable.create(x)
                for x in [sub_start, sub_stop, sub_step]
            ],
            mutable_local=MutableLocal() if self.mutable_local else None,
        )

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
        index = arg.as_python_constant()

        if isinstance(index, slice):
            return self.apply_slice(index)
        return self.apply_index(index)

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx=None):
        return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

    def reconstruct(self, codegen):
        # Emits range(start, stop, step); requires `range` not to be shadowed
        # by a global in the traced frame.
        assert "range" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(codegen.create_load_python_module(range))
        )
        codegen.foreach(self.items)
        codegen.extend_output(create_call_function(3, False))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"range.{name}")
        return self.items[fields.index(name)]
312
313
class CommonListMethodsVariable(BaseListVariable):
    """
    Implement methods common to List and other List-like things
    """

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Model in-place list methods on a locally created (mutable) list.

        Every mutating branch calls ``tx.output.side_effects.mutation(self)``
        before touching ``self.items``. If that call rejects the mutation, the
        tracked contents are therefore left unchanged. Unhandled methods, or
        calls on non-local lists, fall through to the base class.
        """
        from .tensor import SymNodeVariable

        if name == "append" and self.mutable_local:
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items.append(arg)
            return ConstantVariable.create(None)
        elif (
            name == "extend"
            and self.mutable_local
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            seq = arg.force_unpack_var_sequence(tx)
            tx.output.side_effects.mutation(self)
            self.items.extend(seq)
            return ConstantVariable.create(None)
        elif name == "insert" and self.mutable_local:
            assert not kwargs
            idx, value = args
            if isinstance(idx, SymNodeVariable):
                const_idx = idx.evaluate_expr()
            else:
                const_idx = idx.as_python_constant()
            tx.output.side_effects.mutation(self)
            self.items.insert(const_idx, value)
            return ConstantVariable.create(None)
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            tx.output.side_effects.mutation(self)
            return self.items.pop(*[a.as_python_constant() for a in args])
        elif name == "clear" and self.mutable_local:
            assert not kwargs and not args
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            if isinstance(key, SliceVariable):
                # Slice assignment needs the elements of an arbitrary iterable;
                # `value.items` is not that for e.g. a range.
                if not value.has_force_unpack_var_sequence(tx):
                    unimplemented(
                        f"Missing dynamo support for expanding {value} into a list for slice assignment."
                    )
                tx.output.side_effects.mutation(self)
                self.items[key.as_python_constant()] = value.force_unpack_var_sequence(
                    tx
                )
            else:
                tx.output.side_effects.mutation(self)
                self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif name == "copy":
            # List copy() doesn't have args and kwargs
            assert not kwargs
            assert not args
            items = list(self.items)
            return self.modified(items, mutable_local=MutableLocal())
        elif name == "reverse" and self.mutable_local:
            assert not kwargs
            assert not args
            tx.output.side_effects.mutation(self)
            self.items.reverse()
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)
393
394
class ListVariable(CommonListMethodsVariable):
    """VariableTracker for a Python ``list``."""

    def python_type(self):
        return list

    def __repr__(self) -> str:
        return f"{type(self).__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("[", "]")

    def reconstruct(self, codegen):
        # Push each element, then gather them with a single BUILD_LIST.
        codegen.foreach(self.items)
        build_list = create_instruction("BUILD_LIST", arg=len(self.items))
        codegen.append_output(build_list)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``lst[const] = value`` (including slice targets); defer the rest."""
        handles_setitem = (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        )
        if not handles_setitem:
            return super().call_method(tx, name, args, kwargs)

        assert not kwargs
        key, value = args
        tx.output.side_effects.mutation(self)
        if isinstance(key, SliceVariable):
            # Slice targets take the unpacked elements of any iterable.
            if not value.has_force_unpack_var_sequence(tx):
                unimplemented(
                    f"Missing dynamo support for expanding {value} into a list for slice assignment."
                )
            value = value.force_unpack_var_sequence(tx)
        self.items[key.as_python_constant()] = value
        return ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        """Model ``lst.__class__``; every other attribute goes to the base class."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = None if not self.source else AttrSource(self.source, name)
        class_type = self.python_type()
        wrapper = (
            variables.BuiltinVariable
            if class_type is list
            else variables.UserDefinedClassVariable
        )
        return wrapper(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr`` statically for plain lists."""
        if self.python_type() is list:
            return variables.ConstantVariable.create(hasattr([], name))
        return super().call_hasattr(tx, name)
453
454
class DequeVariable(CommonListMethodsVariable):
    """VariableTracker for a ``collections.deque`` (modelled as a list of items)."""

    def python_type(self):
        return collections.deque

    def debug_repr(self):
        return self.debug_repr_helper("deque([", "])")

    def reconstruct(self, codegen):
        # Emits collections.deque([*items]); requires `deque` not to be a
        # global name in the traced frame.
        assert "deque" not in codegen.tx.f_globals
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_python_module(collections.deque)
            )
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_LIST", arg=len(self.items)),
                *create_call_function(1, False),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Model deque-specific methods; anything else goes to the list methods."""
        if (
            name == "__setitem__"
            and self.mutable_local
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            # deques support integer indexing only (no slice assignment).
            assert key.is_python_constant() and isinstance(
                key.as_python_constant(), int
            )
            tx.output.side_effects.mutation(self)
            self.items[key.as_python_constant()] = value
            return ConstantVariable.create(None)
        elif (
            name == "extendleft"
            and self.mutable_local
            and args
            and args[0].has_force_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (arg,) = args
            # extendleft inserts elements one at a time at the front, so they
            # end up reversed. Build a new list rather than reversing the
            # sequence returned by the unpack call in place.
            prefix = list(reversed(arg.force_unpack_var_sequence(tx)))
            tx.output.side_effects.mutation(self)
            self.items = prefix + list(self.items)
            return ConstantVariable.create(None)
        elif name == "popleft" and self.mutable_local:
            assert not args
            assert not kwargs
            item = self.items[0]
            tx.output.side_effects.mutation(self)
            self.items = self.items[1:]
            return item
        elif name == "appendleft" and self.mutable_local:
            assert not kwargs
            (arg,) = args
            tx.output.side_effects.mutation(self)
            self.items = [arg] + list(self.items)
            return ConstantVariable.create(None)
        else:
            return super().call_method(tx, name, args, kwargs)
525
526
class TupleVariable(BaseListVariable):
    """VariableTracker for a Python ``tuple``.

    There are no tuple-specific methods: ``call_method`` is inherited
    unchanged from ``BaseListVariable``.
    """

    def python_type(self):
        return tuple

    def __repr__(self) -> str:
        return f"{type(self).__name__}(length={len(self.items)})"

    def debug_repr(self):
        return self.debug_repr_helper("(", ")")

    def reconstruct(self, codegen):
        # Push each element, then gather them with a single BUILD_TUPLE.
        codegen.foreach(self.items)
        build_tuple = create_instruction("BUILD_TUPLE", arg=len(self.items))
        codegen.append_output(build_tuple)

    def var_getattr(self, tx, name):
        """Model ``tup.__class__``; every other attribute goes to the base class."""
        if name != "__class__":
            return super().var_getattr(tx, name)

        source = None if not self.source else AttrSource(self.source, name)
        class_type = self.python_type()
        wrapper = (
            variables.BuiltinVariable
            if class_type is tuple
            else variables.UserDefinedClassVariable
        )
        return wrapper(class_type, source=source)

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Answer ``hasattr`` statically for plain tuples."""
        if self.python_type() is tuple:
            return variables.ConstantVariable.create(hasattr((), name))
        return super().call_hasattr(tx, name)
564
565
class SizeVariable(TupleVariable):
    """torch.Size(...)"""

    _nonvar_fields = {
        "proxy",
        *TupleVariable._nonvar_fields,
    }

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ) -> None:
        # An explicit proxy, when given, is returned as-is by as_proxy().
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("torch.Size([", "])")

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        """Return an FX value standing for this size.

        torch.Size subclasses tuple but rejects members that are not int-like,
        including fx Proxy/Node objects. Unlike other containers, it therefore
        cannot be represented as ``torch.Size([proxy0, proxy1])``. When any
        element is a Proxy, this instead emits a ``call_function`` node that
        builds the torch.Size from the element proxies. If no element is a
        Proxy, a concrete torch.Size is returned.
        """
        if self.proxy is not None:
            return self.proxy

        proxies = self._as_proxy()
        tracer = next(
            (p.tracer for p in proxies if isinstance(p, torch.fx.Proxy)), None
        )
        if tracer is None:
            return torch.Size(proxies)

        size_proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        example = torch.Size(
            [p if isinstance(p, int) else p.node.meta["example_value"] for p in proxies]
        )
        set_example_value(size_proxy.node, example)
        return size_proxy

    def reconstruct(self, codegen):
        # Emits torch.Size((*items,)).
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size"))
        codegen.foreach(self.items)
        codegen.extend_output(
            [create_instruction("BUILD_TUPLE", arg=len(self.items))]
            + create_call_function(1, False)
        )

    def unpack_var_sequence(self, tx):
        return list(self.items)

    def numel(self, tx):
        """Return the product of all dimensions as a VariableTracker.

        Constant dimensions are folded into one Python int. Symbolic dimensions
        are multiplied in afterwards, and only when the constant part is not
        zero. A constant factor of 1 is dropped rather than multiplied.
        """
        from .builtin import BuiltinVariable
        from .tensor import SymNodeVariable

        const_part = 1
        symbolic = []
        for dim in self.items:
            if isinstance(dim, ConstantVariable):
                const_part *= dim.value
                continue
            assert isinstance(dim, SymNodeVariable), type(dim)
            # Delay proxy calls until we know they are necessary.
            symbolic.append(dim)

        acc = ConstantVariable.create(const_part)
        if symbolic and const_part == 1:
            acc, *symbolic = symbolic

        if const_part == 0 or not symbolic:
            return acc

        mul = BuiltinVariable(operator.mul)
        return functools.reduce(
            lambda lhs, rhs: mul.call_function(tx, [lhs, rhs], {}), symbolic, acc
        )

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Handle ``__getitem__`` (symbolic-aware) and ``numel``; defer the rest."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.get_item_dyn(tx, args[0])
        if name == "numel":
            assert not args and not kwargs
            return self.numel(tx)
        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index with a constant/symbolic int, or slice into a new SizeVariable."""
        from .tensor import SymNodeVariable

        index = (
            arg.sym_num
            if isinstance(arg, SymNodeVariable)
            else arg.as_python_constant()
        )
        if isinstance(index, slice):
            return SizeVariable(self.items[index])
        assert isinstance(index, (int, torch.SymInt))
        return self.items[index]

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(torch.Size, name))
707
708
class NamedTupleVariable(TupleVariable):
    """VariableTracker for an instance of a namedtuple-like class.

    ``tuple_cls`` is the concrete class. ``items`` hold the values in the
    order that ``namedtuple_fields(tuple_cls)`` reports the fields.
    """

    _nonvar_fields = {
        "tuple_cls",
        *TupleVariable._nonvar_fields,
    }

    def __init__(self, items, tuple_cls, **kwargs) -> None:
        super().__init__(items, **kwargs)
        self.tuple_cls = tuple_cls

    def debug_repr(self):
        return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        # The class constructor takes the field values positionally.
        return self.python_type()(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # torch.Size cannot hold proxies (see SizeVariable); never build one here.
        assert self.python_type() is not torch.Size
        return self.python_type()(*self._as_proxy())

    def reconstruct(self, codegen):
        # Use the class's `_make(iterable)` when present. Otherwise call the
        # class itself; either way it gets a single tuple argument.
        create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
        codegen.add_push_null(
            lambda: codegen.append_output(codegen._create_load_const(create_fn))
        )
        codegen.foreach(self.items)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.items)),
            ]
            + create_call_function(1, False)
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve ``name`` as a field first, then as a method on ``tuple_cls``.

        Classmethods, staticmethods and plain functions are looked up
        statically on ``tuple_cls`` and wrapped. Anything else defers to
        ``TupleVariable.var_getattr``.
        """

        def check_and_create_method():
            method = inspect.getattr_static(self.tuple_cls, name, None)
            if isinstance(method, classmethod):
                # We need the unbounded cls method to avoid the inline __self__
                return UserMethodVariable(
                    method.__func__,
                    variables.UserDefinedClassVariable(self.tuple_cls),
                )
            elif isinstance(method, staticmethod):
                return UserFunctionVariable(method.__func__)
            elif inspect.isfunction(method):
                return UserMethodVariable(method, self)
            else:
                return None

        fields = namedtuple_fields(self.tuple_cls)
        if name in fields:
            return self.items[fields.index(name)]
        method = check_and_create_method()
        if method is None:
            return super().var_getattr(tx, name)
        return method

    def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
771
772
class SliceVariable(BaseListVariable):
    """Models a ``slice`` object.

    ``items`` is normalised to ``[start, stop, step]``. Omitted components are
    ``ConstantVariable(None)``, mirroring ``slice(stop)``,
    ``slice(start, stop)`` and ``slice(start, stop, step)``.
    """

    def __init__(self, items, **kwargs) -> None:
        start, stop, step = [variables.ConstantVariable.create(None)] * 3

        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            raise AssertionError

        # A tensor-valued component (start, stop or step) depends on runtime
        # data, so it cannot be turned into a Python slice at trace time.
        if any(isinstance(x, variables.TensorVariable) for x in (start, stop, step)):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def debug_repr(self):
        return self.debug_repr_helper("slice(", ")")

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        # guard_if_dyn specialises symbolic components to concrete values.
        return slice(*[guard_if_dyn(x) for x in self.items])

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))

    def var_getattr(self, tx: "InstructionTranslator", name):
        fields = ["start", "stop", "step"]
        if name not in fields:
            unimplemented(f"slice.{name}")
        return self.items[fields.index(name)]
815
816
class ListIteratorVariable(IteratorVariable):
    """Iterator over a fixed list of VariableTrackers.

    ``index`` is the position of the next element to yield. Elements before
    it have already been consumed.
    """

    _nonvar_fields = {
        "index",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(self, items, index: int = 0, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(items, list)
        # Elements are deliberately not type-checked here: checking every item
        # was too slow, see
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492
        self.items = items
        self.index = index

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={len(self.items)}, index={self.index!r})"

    def next_variable(self, tx):
        """Yield the next element, or raise an observed StopIteration when exhausted."""
        assert self.mutable_local
        position = self.index
        if position >= len(self.items):
            raise_observed_exception(StopIteration, tx, self)

        tx.output.side_effects.mutation(self)
        self.index = position + 1
        return self.items[position]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        """Handle ``__contains__`` over the not-yet-consumed elements; defer the rest."""
        if name != "__contains__":
            return super().call_method(tx, name, args, kwargs)

        assert len(args) == 1
        assert not kwargs
        return iter_contains(self.items[self.index :], args[0], tx)

    def python_type(self):
        return type(iter([]))

    def as_python_constant(self):
        # Only an untouched iterator is representable as a constant.
        if self.index > 0:
            raise NotImplementedError
        return iter([item.as_python_constant() for item in self.items])

    def unpack_var_sequence(self, tx):
        return list(self.items[self.index :])

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        return self.unpack_var_sequence(tx)

    def reconstruct(self, codegen):
        # Rebuild as iter(tuple(remaining_items)).
        pending = self.items[self.index :]
        codegen.foreach(pending)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(pending)),
                create_instruction("GET_ITER"),
            ]
        )
883
884
class TupleIteratorVariable(ListIteratorVariable):
    """Iterator over a tuple; all behaviour is inherited from ListIteratorVariable."""
887
888
class RestrictedListSubclassVariable(ListVariable):
    """
    Track an instance of a user-defined ``list`` subclass as a plain list.

    This is a narrow special case of UserDefinedObjectVariable, applicable when:
        1) the user class derives directly (and only) from ``list``, and
        2) it does not override anything ``list`` already defines; it merely
           adds new methods.

    Under those conditions graph breaks can be avoided by skipping the generic
    UserDefinedObjectVariable machinery and reusing ListVariable instead.
    """

    _nonvar_fields = {*ListVariable._nonvar_fields, "user_cls", "user_cls_source"}
    # Names that may appear in the subclass's own ``__dict__`` without being
    # treated as a conflict with ``list``.
    _allowed_names = {
        "__call__",
        "__dict__",
        "__doc__",
        "__module__",
        "__name__",
        "__qualname__",
    }
    # Names that disqualify the subclass whenever they are defined on it,
    # whether or not ``list`` itself has them.
    _disallowed_names = {
        "__getattr__",
        "__getattribute__",
        "__setattr__",
    }

    @classmethod
    def _is_non_conflicting_subclass(
        cls,
        user_cls: type,
        python_cls: type,
    ):
        """Check that user_cls directly subclasses python_cls (e.g. list) without overriding it.

        Returns False unless ``user_cls`` is a plain class (exact type ``type``)
        whose only base is ``python_cls`` and whose MRO is
        ``(user_cls, python_cls, object)``. After that, every name in
        ``user_cls.__dict__`` (apart from ``_allowed_names``) must be absent
        from ``python_cls`` and must not be in ``_disallowed_names``.
        """
        is_direct_subclass = (
            istype(user_cls, type)
            and user_cls.__bases__ == (python_cls,)
            and user_cls.__mro__ == (user_cls, python_cls, object)
        )
        if not is_direct_subclass:
            return False  # not subclass
        own_names = set(user_cls.__dict__.keys()) - cls._allowed_names
        return all(
            not hasattr(python_cls, attr) and attr not in cls._disallowed_names
            for attr in own_names
        )

    @classmethod
    def is_matching_cls(cls, user_cls: type):
        """Return whether ``user_cls`` is a ``list`` subclass this variable can model."""
        return cls._is_non_conflicting_subclass(user_cls, list)

    def __init__(
        self, items, *, user_cls: type, user_cls_source: Source, **kwargs
    ) -> None:
        super().__init__(items=items, **kwargs)
        # The concrete user subclass and the source it can be reloaded from.
        self.user_cls = user_cls
        self.user_cls_source = user_cls_source
        assert istype(user_cls, type)
        assert isinstance(user_cls_source, Source)

    def debug_repr(self):
        # Calling the user class is safe here: no attribute of ``list``
        # (``__init__`` included) may be overridden. The output therefore
        # looks exactly like a list's repr, since ``__repr__`` cannot be
        # overridden either. That matches eager behavior but can mislead;
        # query the type when the distinction matters.
        placeholders = [Lit(elem.debug_repr()) for elem in self.items]
        return repr(self.user_cls(placeholders))

    def python_type(self):
        """Report the user's list subclass rather than ``list``."""
        return self.user_cls

    def as_proxy(self):
        """Return a plain Python list holding the proxy of each item."""
        proxies = [elem.as_proxy() for elem in self.items]
        return proxies

    def as_python_constant(self):
        # Instances of the user subclass are never folded into constants.
        raise NotImplementedError

    def is_python_constant(self):
        return False

    @property
    def value(self):
        # Raising AttributeError makes this variable look as if it has no
        # ``value`` attribute (``hasattr(self, "value")`` is False).
        raise AttributeError("value")

    def modified(self, items, **kwargs):
        """Build a sibling variable over ``items`` with the same user class and source."""
        return type(self)(
            items,
            user_cls=self.user_cls,
            user_cls_source=self.user_cls_source,
            **kwargs,
        )

    def reconstruct(self, codegen):
        # Emit ``user_cls(<items>)``. First push the class (through
        # add_push_null), then the items via the parent ``reconstruct``, and
        # finally a call with one positional argument.
        codegen.add_push_null(lambda: codegen(self.user_cls_source))
        super().reconstruct(codegen)
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        """Inline plain functions defined on the subclass; defer other names to ListVariable.

        A name found in the subclass's own ``__dict__`` whose value is not a
        plain function is reported as unsupported via ``unimplemented``.
        """
        class_namespace = self.user_cls.__dict__
        if name in class_namespace:
            candidate = class_namespace[name]
            if isinstance(candidate, types.FunctionType):
                # A plain function on the subclass: inline it as a bound method.
                method_source = AttrSource(self.user_cls_source, name)
                bound = UserMethodVariable(candidate, self, source=method_source)
                return bound.call_function(tx, args, kwargs)
            unimplemented(
                f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}"
            )
        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Calling an instance dispatches to its ``__call__`` through ``call_method``."""
        return self.call_method(tx, "__call__", args, kwargs)
1011