xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/ctx_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2import dataclasses
3import inspect
4import sys
5import warnings
6from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
7
8import torch._C
9from torch._guards import Guard
10
11from .. import variables
12from ..bytecode_transformation import (
13    create_call_function,
14    create_instruction,
15    create_setup_with,
16)
17from ..device_interface import get_interface_for_device
18from ..exc import unimplemented, Unsupported
19from ..guards import GuardBuilder, install_guard
20from ..source import AttrSource, GlobalStateSource
21from .base import VariableTracker
22from .functions import (
23    NestedUserFunctionVariable,
24    UserFunctionVariable,
25    UserMethodVariable,
26    WrappedUserFunctionVariable,
27    WrappedUserMethodVariable,
28)
29from .user_defined import UserDefinedObjectVariable
30
31
32if TYPE_CHECKING:
33    from torch._dynamo.symbolic_convert import InstructionTranslator
34
35
36@dataclasses.dataclass
37class ContextMangerState:
38    """
39    Mutating `self` in a VariableTracker is not allowed because VariableTrackers
40    get copied.  This is a mutable container pointed to by context managers
41    that won't get copied, so it is safe to mutate.
42    """
43
44    cleanup_fn: Optional[Callable] = None
45    proxy: Optional[torch.fx.Proxy] = None
46
47    def cleanup(self):
48        if self.cleanup_fn is not None:
49            self.cleanup_fn()
50            self.cleanup_fn = None
51
52    def cleanup_assert(self):
53        assert self.cleanup_fn, "multiple exits?"
54        self.cleanup()
55
56
57class ContextWrappingVariable(VariableTracker):
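    """Base class for VariableTrackers that model context managers traced by Dynamo."""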
58    _nonvar_fields = {
59        "cm_obj",
60        "target_values",
61        "initial_values",
62        "state",
63        *VariableTracker._nonvar_fields,
64    }
65
66    def __init__(
67        self, target_values, initial_values=None, *, state=None, **kwargs
68    ) -> None:
69        super().__init__(**kwargs)
70        self.target_values = target_values
71        self.initial_values = initial_values
72        self.state = ContextMangerState() if state is None else state
73
74    def enter(self, tx):
75        self._call_func(tx, self.target_values)
76        self.set_cleanup_hook(tx)
77        return variables.ConstantVariable.create(None)
78
79    def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
80        if fn is None:
81
82            def fn():
83                self._call_func(tx, self.initial_values)
84
85        self.state.cleanup_fn = fn
86        tx.output.add_cleanup_hook(self.state.cleanup)
87
88    def exit(self, tx: "InstructionTranslator", *args):
89        self.state.cleanup_assert()
90        return variables.ConstantVariable.create(None)
91
92    def reconstruct_type(self, codegen):
93        codegen(
94            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
95        )
96
97    def reconstruct(self, codegen):
98        codegen.add_push_null(lambda: self.reconstruct_type(codegen))
99        target_values = self.target_values
100        if not target_values:
101            target_values = ()
102        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
103        codegen.extend_output(create_call_function(len(target_values), False))
104
105    def module_name(self):
106        raise NotImplementedError("module_name called on base")
107
108    def fn_name(self):
109        raise NotImplementedError("fn_name called on base")
110
111    def call_function(
112        self,
113        tx: "InstructionTranslator",
114        args: "List[VariableTracker]",
115        kwargs: "Dict[str, VariableTracker]",
116    ) -> "VariableTracker":
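        # Support using the context manager as a function decorator: wrap the
        # decorated function so this context is entered around its calls.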
117        assert len(args) == 1
118        if isinstance(args[0], NestedUserFunctionVariable):
119            args[0] = UserFunctionVariable(args[0].get_function())
120        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))
121
122        if isinstance(args[0], UserMethodVariable):
123            return WrappedUserMethodVariable(args[0], self)
124
125        if isinstance(args[0], UserFunctionVariable):
126            return WrappedUserFunctionVariable(args[0], self)
127
128
129class GenericContextWrappingVariable(UserDefinedObjectVariable):
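    """Wraps a general user-defined context manager by inlining its __enter__ and __exit__ methods."""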
130    # Some methods in ContextWrappingVariable assume the arguments are
131    # Python constants, which might not always be the case here.
132    def __init__(self, cm_obj, **kwargs) -> None:
133        assert cm_obj is not None
134        super().__init__(
135            value=cm_obj,
136            value_type=cm_obj.__class__,
137            **kwargs,
138        )
139        self.cm_obj = cm_obj
140
141    def module_name(self):
142        return self.cm_obj.__module__
143
144    def fn_name(self):
145        return type(self.cm_obj).__name__
146
147    def enter(self, tx):
148        source = None if self.source is None else AttrSource(self.source, "__enter__")
149        try:
150            return variables.UserMethodVariable(
151                self.cm_obj.__enter__.__func__,
152                self,
153                source=source,
154            ).call_function(tx, [], {})
155        except Unsupported as e:
156            unimplemented(
157                f"Unsupported context manager {self.cm_obj}'s __enter__ function",
158                from_exc=e,
159            )
160
161    def exit(self, tx: "InstructionTranslator", *args):
162        source = None if self.source is None else AttrSource(self.source, "__exit__")
163        try:
164            x = variables.UserMethodVariable(
165                self.cm_obj.__exit__.__func__,
166                self,
167                source=source,
168            ).call_function(
169                tx,
170                [
171                    variables.ConstantVariable.create(None),
172                    variables.ConstantVariable.create(None),
173                    variables.ConstantVariable.create(None),
174                ],
175                {},
176            )
177        except Unsupported as e:
178            unimplemented(
179                f"Unsupported context manager {self.cm_obj}'s __exit__ function",
180                from_exc=e,
181            )
182
183        tx.generic_context_manager_depth -= 1
184        return x
185
186
187class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
188    """represents torch grad requries grad"""
189
190    @staticmethod
191    def create(tx: "InstructionTranslator", target_values, **kwargs):
192        return GradInplaceRequiresGradCtxManagerVariable(
193            target_values=target_values,
194            initial_values=None,
195            **kwargs,
196        )
197
198    def enter(self, tx):
199        [enabled] = self.target_values
200        self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
201        torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
202        self.set_cleanup_hook(
203            tx,
204            lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
205                self.prev_state
206            ),
207        )
208        self.state.proxy = tx.output.create_node(
209            "call_function",
210            torch._C._functorch.set_inplace_requires_grad_allowed,
211            (enabled,),
212            {},
213        )
214        return variables.ConstantVariable.create(None)
215
216    def exit(self, tx: "InstructionTranslator", *args):
217        self.state.cleanup()
218        tx.output.create_node(
219            "call_function",
220            torch._C._functorch.set_inplace_requires_grad_allowed,
221            (self.prev_state,),
222            {},
223        )
224        return variables.ConstantVariable.create(None)
225
226
227class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
228    """represents torch.func.jvp increment/decrement nesting"""
229
230    # A guard is needed as the jvp level is baked into the torch FX graph
231    # This is fine if jvp is only called from within the function
232    # being compiled. But the FX graph may be invalid in the case of a jvp
233    # call from eager that calls the compiled function, as the jvp levels
234    # may be different.
235    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
236
237    @staticmethod
238    def create(tx: "InstructionTranslator", **kwargs):
239        var = JvpIncrementNestingCtxManagerVariable(
240            target_values=None,
241            initial_values=None,
242            **kwargs,
243        )
244        return var
245
246    def enter(self, tx):
247        install_guard(self._guards_singleton)
248        jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
249        self.set_cleanup_hook(
250            tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
251        )
252        self.state.proxy = tx.output.create_node(
253            "call_function",
254            torch._C._functorch._jvp_increment_nesting,
255            (),
256            {},
257        )
258        return variables.ConstantVariable.create(jvp_level)
259
260    def exit(self, tx: "InstructionTranslator", *args):
261        self.state.cleanup()
262        tx.output.create_node(
263            "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
264        )
265        return variables.ConstantVariable.create(None)
266
267
268class SetFwdGradEnabledContextManager(ContextWrappingVariable):
269    """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""
270
271    @staticmethod
272    def create(tx: "InstructionTranslator", target_values, **kwargs):
273        return SetFwdGradEnabledContextManager(
274            target_values=target_values,
275            initial_values=None,
276            **kwargs,
277        )
278
279    def enter(self, tx):
280        [mode] = self.target_values
281        self.prev_state = torch._C._is_fwd_grad_enabled()
282        torch._C._set_fwd_grad_enabled(mode)
283        self.set_cleanup_hook(
284            tx,
285            lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
286        )
287        self.state.proxy = tx.output.create_node(
288            "call_function",
289            torch._C._set_fwd_grad_enabled,
290            (mode,),
291            {},
292        )
293        return variables.ConstantVariable.create(None)
294
295    def exit(self, tx: "InstructionTranslator", *args):
296        self.state.cleanup()
297        tx.output.create_node(
298            "call_function",
299            torch._C._set_fwd_grad_enabled,
300            (self.prev_state,),
301            {},
302        )
303        return variables.ConstantVariable.create(None)
304
305
306class DualLevelContextManager(ContextWrappingVariable):
307    """Represents torch.autograd.forward_ad.dual_level ctx manager"""
308
309    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)
310
311    @staticmethod
312    def create(tx: "InstructionTranslator", **kwargs):
313        return DualLevelContextManager(
314            target_values=None,
315            initial_values=None,
316            **kwargs,
317        )
318
319    def enter(self, tx):
320        install_guard(self._guards_singleton)
321        self.new_level = torch.autograd.forward_ad.enter_dual_level()
322        self.set_cleanup_hook(
323            tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
324        )
325        self.state.proxy = tx.output.create_node(
326            "call_function",
327            torch._C._enter_dual_level,
328            (),
329            {},
330        )
331        return variables.ConstantVariable.create(self.new_level)
332
333    def exit(self, tx: "InstructionTranslator", *args):
334        self.state.cleanup()
335        tx.output.create_node(
336            "call_function",
337            torch._C._exit_dual_level,
338            (self.new_level,),
339            {},
340        )
341        return variables.ConstantVariable.create(None)
342
343
344class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
345    """represents torch.func.grad increment/decrement nesting"""
346
347    # A guard is needed as the grad level is baked into the torch FX graph
348    # This is fine if grad is only called from within the function
349    # being compiled. But the FX graph may be invalid in the case of a grad
350    # call from eager that calls the compiled function, as the grad levels
351    # may be different.
352    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
353
354    @staticmethod
355    def create(tx: "InstructionTranslator", **kwargs):
356        var = GradIncrementNestingCtxManagerVariable(
357            target_values=None,
358            initial_values=None,
359            **kwargs,
360        )
361        return var
362
363    def enter(self, tx):
364        install_guard(self._guards_singleton)
365        grad_level = torch._C._functorch._grad_increment_nesting()
366        self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
367        self.state.proxy = tx.output.create_node(
368            "call_function",
369            torch._C._functorch._grad_increment_nesting,
370            (),
371            {},
372        )
373        return variables.ConstantVariable.create(grad_level)
374
375    def exit(self, tx: "InstructionTranslator", *args):
376        self.state.cleanup()
377        tx.output.create_node(
378            "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
379        )
380        return variables.ConstantVariable.create(None)
381
382
383class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
384    """Delay a call to warnings.catch_warnings"""
385
386    @staticmethod
387    def create(tx: "InstructionTranslator", catch_warnings_args):
388        return CatchWarningsCtxManagerVariable(
389            catch_warnings_args=catch_warnings_args,
390            target_values=None,
391            initial_values=None,
392        )
393
394    def __init__(self, catch_warnings_args, **kwargs) -> None:
395        assert isinstance(catch_warnings_args, dict), catch_warnings_args
396        super().__init__(**kwargs)
397        self.catch_warnings_args = catch_warnings_args
398
399    def enter(self, tx):
400        kwargs = {
401            k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
402        }
403        ctx_val = warnings.catch_warnings(**kwargs)
404        self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
405        return variables.ConstantVariable.create(ctx_val.__enter__())
406
407    def reconstruct(self, cg):
408        cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings"))
409        cg.foreach(self.catch_warnings_args.values())
410        keys = tuple(self.catch_warnings_args.keys())
411        cg.extend_output(cg.create_call_function_kw(len(keys), keys, False))
412
413
414class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
415    """represents torch VMap increment/decrement nesting"""
416
417    # A guard is needed as the vmap level is baked into the torch FX graph
418    # generated. This is fine if vmap is only called from within the function
419    # being compiled. But the FX graph may be invalid in the case of a vmap
420    # call from eager that calls the compiled function, as the vmap levels
421    # may be different.
422    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
423
424    @staticmethod
425    def create(tx: "InstructionTranslator", target_values, **kwargs):
426        var = VmapIncrementNestingCtxManagerVariable(
427            target_values=target_values,
428            initial_values=None,
429            **kwargs,
430        )
431        return var
432
433    def enter(self, tx):
434        install_guard(self._guards_singleton)
435        batch_size, randomness = self.target_values
436        vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
437        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
438        self.state.proxy = tx.output.create_node(
439            "call_function",
440            torch._C._functorch._vmap_increment_nesting,
441            (batch_size, randomness),
442            {},
443        )
444        return variables.ConstantVariable.create(vmap_level)
445
446    def exit(self, tx: "InstructionTranslator", *args):
447        self.state.cleanup()
448        tx.output.create_node(
449            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
450        )
451        return variables.ConstantVariable.create(None)
452
453
454class GradModeVariable(ContextWrappingVariable):
455    """represents torch.{no_grad,enable_grad,set_grad_mode}()"""
456
457    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)
458
459    @staticmethod
460    def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs):
461        var = GradModeVariable(
462            target_values=[target_value],
463            initial_values=[torch.is_grad_enabled()],
464            **kwargs,
465        )
466        if initialized:
467            var._call_func(tx, var.target_values)
468        return var
469
470    def __init__(
471        self, target_values, initial_values=None, initialized=True, **kwargs
472    ) -> None:
473        super().__init__(
474            target_values=target_values, initial_values=initial_values, **kwargs
475        )
476        install_guard(self._guards_singleton)
477
478    def enter(self, tx):
479        self._call_func(tx, self.target_values)
480        return variables.ConstantVariable.create(None)
481
482    def exit(self, tx: "InstructionTranslator", *args):
483        self._call_func(tx, self.initial_values)
484        return variables.ConstantVariable.create(None)
485
486    def call_function(
487        self,
488        tx: "InstructionTranslator",
489        args: "List[VariableTracker]",
490        kwargs: "Dict[str, VariableTracker]",
491    ):
492        self._call_func(tx, self.initial_values)  # undo eager initialization
493        return super().call_function(tx, args, kwargs)
494
495    def _call_func(self, tx: "InstructionTranslator", values):
496        assert len(values) == 1
497        value = values[0]
498        # Coalesce grad mode mutations
499        if torch.is_grad_enabled() != value:
500            tx.output.create_node(
501                "call_function", torch._C._set_grad_enabled, (value,), {}
502            )
503            torch._C._set_grad_enabled(value)
504
505    def module_name(self):
506        return "torch"
507
508    def fn_name(self):
509        return "set_grad_enabled"
510
511
512class InferenceModeVariable(ContextWrappingVariable):
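    """represents torch.inference_mode()"""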
513    @staticmethod
514    def create(tx: "InstructionTranslator", target_value, **kwargs):
515        var = InferenceModeVariable(
516            [target_value], initial_values=[torch.is_inference_mode_enabled()], **kwargs
517        )
518        return var
519
520    def __init__(
521        self,
522        target_values,
523        initial_values=None,
524        **kwargs,
525    ) -> None:
526        if initial_values is None:
527            # This must be called here since function defaults are evaluated at import time
528            initial_values = torch.is_inference_mode_enabled()
529        super().__init__(
530            target_values=target_values, initial_values=initial_values, **kwargs
531        )
532        self.target_values = target_values
533
534    def exit(self, tx: "InstructionTranslator", *args):
535        self.state.cleanup_assert()
536        tx.output.create_node(
537            "call_function",
538            torch.autograd.grad_mode._exit_inference_mode,
539            (self.state.proxy,),
540            {},
541        )
542
543    def enter(self, tx):
544        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
545        self.set_cleanup_hook(
546            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
547        )
548        self.state.proxy = tx.output.create_node(
549            "call_function",
550            torch.autograd.grad_mode._enter_inference_mode,
551            (*self.target_values,),
552            {},
553        )
554
555    def module_name(self):
556        return "torch"
557
558    def fn_name(self):
559        return "inference_mode"
560
561
562class CUDADeviceVariable(ContextWrappingVariable):
563    """represents torch.cuda.device"""
564
565    @staticmethod
566    def create(tx: "InstructionTranslator", device, **kwargs):
567        var = CUDADeviceVariable(
568            target_values=[torch.cuda._get_device_index(device, optional=True)],
569            initial_values=None,
570            **kwargs,
571        )
572        return var
573
574    def __init__(
575        self,
576        target_values,
577        initial_values=None,
578        **kwargs,
579    ) -> None:
580        super().__init__(
581            target_values=target_values, initial_values=initial_values, **kwargs
582        )
583        self.target_values = target_values
584
585    def exit(self, tx: "InstructionTranslator", *args):
586        self.state.cleanup_assert()
587        tx.output.create_node(
588            "call_function",
589            torch.cuda._maybe_exchange_device,
590            (self.state.proxy,),
591            {},
592        )
593        return variables.ConstantVariable.create(False)
594
595    def enter(self, tx):
596        prev_idx = torch.cuda._exchange_device(*self.target_values)
597        self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
598        self.state.proxy = tx.output.create_node(
599            "call_function",
600            torch.cuda._exchange_device,
601            (*self.target_values,),
602            {},
603        )
604
605    def module_name(self):
606        return "torch.cuda"
607
608    def fn_name(self):
609        return "device"
610
611
612class TorchFunctionDisableVariable(ContextWrappingVariable):
613    """represents whether torch function overrides are enabled or not"""
614
615    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)
616
617    @staticmethod
618    def create(tx: "InstructionTranslator", **kwargs):
619        var = TorchFunctionDisableVariable(
620            target_values=[False],
621            initial_values=[tx.output.torch_function_enabled],
622            **kwargs,
623        )
624        # mlazos: I think this is here to make sure we don't reinvoke on clone()
625        var._call_func(tx, [False])
626        var.set_cleanup_hook(tx)
627        return var
628
629    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
630        super().__init__(
631            target_values=target_values, initial_values=initial_values, **kwargs
632        )
633        install_guard(self._guards_singleton)
634
635    def enter(self, tx):
636        return variables.ConstantVariable.create(None)
637
638    def _call_func(self, tx: "InstructionTranslator", values):
639        assert len(values) == 1
640        tx.output.set_torch_function_state(values[0])
641
642
643class DeterministicAlgorithmsVariable(ContextWrappingVariable):
644    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""
645
646    _guards_singleton = Guard(
647        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
648    )
649
650    @staticmethod
651    def create(tx: "InstructionTranslator", target_value, **kwargs):
652        var = DeterministicAlgorithmsVariable(
653            target_values=[target_value],
654            initial_values=[torch.are_deterministic_algorithms_enabled()],
655            **kwargs,
656        )
657        var._call_func(tx, [target_value])
658        var.set_cleanup_hook(tx)
659        return var
660
661    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
662        super().__init__(
663            target_values=target_values, initial_values=initial_values, **kwargs
664        )
665        install_guard(self._guards_singleton)
666
667    def enter(self, tx):
668        return variables.ConstantVariable.create(None)
669
670    def _call_func(self, tx: "InstructionTranslator", values):
671        assert len(values) == 1
672        value = values[0]
673        tx.output.create_node(
674            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
675        )
676        torch._C._set_deterministic_algorithms(value)
677
678    def module_name(self):
679        return "torch"
680
681    def fn_name(self):
682        return "use_deterministic_algorithms"
683
684
685class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
686    """represents torch.autograd.graph.disable_saved_tensors_hook."""
687
688    @staticmethod
689    def create(tx: "InstructionTranslator", target_value, **kwargs):
690        var = DisabledSavedTensorsHooksVariable(
691            target_values=[target_value],
692            initial_values=[
693                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
694            ],
695            **kwargs,
696        )
697        var._call_func(tx, [target_value])
698        var.set_cleanup_hook(tx)
699        return var
700
701    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
702        super().__init__(
703            target_values=target_values, initial_values=initial_values, **kwargs
704        )
705
706    def enter(self, tx):
707        return variables.ConstantVariable.create(None)
708
709    def _call_func(self, tx: "InstructionTranslator", values):
710        assert len(values) == 1
711        value = values[0]
712        if value is not None:
713            # Disable `saved_tensors_hooks` with message (`value`)
714            # OR
715            # we are exiting this context and restoring the previous message.
716            tx.output.create_node(
717                "call_function",
718                torch._C._autograd._saved_tensors_hooks_disable,
719                (value,),
720                {},
721            )
722            torch._C._autograd._saved_tensors_hooks_disable(value)
723        else:
724            # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
725            tx.output.create_node(
726                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
727            )
728            torch._C._autograd._saved_tensors_hooks_enable()
729
730    def module_name(self):
731        return "torch.autograd.graph"
732
733    def fn_name(self):
734        return "disable_saved_tensors_hooks"
735
736
737class AutocastModeVariable(ContextWrappingVariable):
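    """represents torch.amp.autocast and torch.{cuda,cpu}.amp.autocast"""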
738    @staticmethod
739    def create(func, args, kwargs):
740        assert func in [
741            torch.amp.autocast_mode.autocast,
742            torch.cuda.amp.autocast,
743            torch.cpu.amp.autocast,
744        ]
745        # device_type : str,
746        # dtype : Optional[_dtype] = None,
747        # enabled : bool = True,
748        # cache_enabled : Optional[bool] = None
749        bound_args = inspect.signature(func).bind(*args, **kwargs)
750        bound_args.apply_defaults()
751        target_values = []
752        kwargs.clear()
753
754        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
755            if key == "device_type" and func in [
756                torch.cuda.amp.autocast,
757                torch.cpu.amp.autocast,
758            ]:
759                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
760            else:
761                arg = bound_args.arguments[key]
762            if isinstance(arg, VariableTracker):
763                target_values.append(arg.as_python_constant())
764            else:
765                target_values.append(arg)
766
767        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
768        return var
769
770    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
771        super().__init__(
772            target_values=target_values, initial_values=initial_values, **kwargs
773        )
774        self.target_values = target_values
775
776    def exit(self, tx: "InstructionTranslator", *args):
777        self.state.cleanup_assert()
778        tx.output.create_node(
779            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
780        )
781
782    def enter(self, tx):
783        ctx = torch.amp._enter_autocast(*self.target_values)
784        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
785        self.state.proxy = tx.output.create_node(
786            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
787        )
788
789    def module_name(self):
790        return "torch.amp.autocast_mode"
791
792    def fn_name(self):
793        return "autocast"
794
795
796class NullContextVariable(ContextWrappingVariable):
797    """
798    This class represents Python contextlib.nullcontext.
799    It's used as a placeholder for other context managers that Dynamo doesn't
800    support yet, e.g., torch.autograd.profiler.record_function.
801    """
802
803    def __init__(self, target_values=None, **kwargs) -> None:
804        super().__init__(target_values=target_values, **kwargs)
805
806    def enter(self, tx):
807        return variables.ConstantVariable.create(None)
808
809    def exit(self, tx: "InstructionTranslator", *args):
810        return variables.ConstantVariable.create(None)
811
812    def module_name(self):
813        return "contextlib"
814
815    def fn_name(self):
816        return "nullcontext"
817
818
819class StreamContextVariable(ContextWrappingVariable):
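    """represents a device stream context: sets the current stream on enter and restores the previous stream on exit"""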
820    @staticmethod
821    def create(tx: "InstructionTranslator", target_value, **kwargs):
822        from .builder import wrap_fx_proxy_cls
823
824        current_stream_method = get_interface_for_device(
825            target_value.device
826        ).current_stream
827        current_stream = wrap_fx_proxy_cls(
828            StreamVariable,
829            tx,
830            tx.output.create_proxy(
831                "call_function",
832                current_stream_method,
833                (None,),
834                {},
835            ),
836        )
837        return StreamContextVariable(
838            target_values=[target_value],
839            initial_values=[current_stream],
840            device=target_value.device,
841            **kwargs,
842        )
843
844    def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
845        super().__init__(
846            target_values=target_values, initial_values=initial_values, **kwargs
847        )
848        self.device = device
849        self.set_stream = get_interface_for_device(self.device).set_stream
850        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
851
852    def enter(self, tx):
853        # stream generated inside the traced function
854        if self.target_values[0].as_proxy() is not None:
855            tx.output.create_proxy(
856                "call_function",
857                self.set_stream,
858                (self.target_values[0].as_proxy(),),
859                {},
860            )
861        # stream passed from outside the traced function
862        else:
863            stream = self.target_values[0].value
864            tx.output.create_proxy(
865                "call_function",
866                self.set_stream_id,
867                (stream.stream_id, stream.device_index, stream.device_type),
868                {},
869            )
870        self.set_stream(self.target_values[0].value)
871        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
872
873    def exit(self, tx: "InstructionTranslator", *args):
874        tx.output.create_proxy(
875            "call_function",
876            self.set_stream,
877            (self.initial_values[0].as_proxy(),),
878            {},
879        )
880        self.state.cleanup_assert()
881
882
883class PreserveVersionContextVariable(ContextWrappingVariable):
884    """
885    Wraps torch.autograd._unsafe_preserve_version_counter
886    """
887
888    @staticmethod
889    def constructor(tx):
890        return variables.LambdaVariable(
891            lambda tensor: PreserveVersionContextVariable(
892                tensor,
893                tensor.var_getattr(tx, "_version"),
894            )
895        )
896
897    def __init__(self, tensor, prev_version, **kwargs) -> None:
898        kwargs.setdefault("target_values", None)
899        super().__init__(**kwargs)
900        self.tensor = tensor
901        self.prev_version = prev_version
902
903    def enter(self, tx):
904        pass
905
906    def exit(self, tx: "InstructionTranslator", *args):
907        from ..tensor_version_op import _unsafe_set_version_counter
908
909        return variables.TorchInGraphFunctionVariable(
910            _unsafe_set_version_counter
911        ).call_function(tx, [self.tensor, self.prev_version], {})
912
913    def reconstruct(self, codegen):
914        unimplemented(
915            "torch.autograd._unsafe_preserve_version_counter with graph break"
916        )
917
918
919class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
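    """represents FSDPParamGroup.use_training_state()"""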
920    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE)
921
922    @staticmethod
923    def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs):
924        var = FSDPParamGroupUseTrainingStateVariable(
925            param_group_var=param_group_var,
926            target_values=[target_value],
927            initial_values=[param_group_var.value._training_state],
928            **kwargs,
929        )
930        return var
931
932    def __init__(
933        self, param_group_var, target_values, initial_values=None, **kwargs
934    ) -> None:
935        super().__init__(
936            target_values=target_values, initial_values=initial_values, **kwargs
937        )
938        self.param_group_var = param_group_var
939        install_guard(self._guards_singleton)
940
941    def enter(self, tx):
942        self._call_func(tx, self.target_values)
943        return variables.ConstantVariable.create(None)
944
945    def exit(self, tx: "InstructionTranslator", *args):
946        self._call_func(tx, self.initial_values)
947        return variables.ConstantVariable.create(None)
948
949    def call_function(
950        self,
951        tx: "InstructionTranslator",
952        args: "List[VariableTracker]",
953        kwargs: "Dict[str, VariableTracker]",
954    ):
955        self._call_func(tx, self.initial_values)  # undo eager initialization
956        return super().call_function(tx, args, kwargs)
957
958    def _call_func(self, tx: "InstructionTranslator", values):
959        assert len(values) == 1
960        value = values[0]
961        if self.param_group_var.value._training_state != value:
962            self.param_group_var.call_method(
963                tx,
964                "__setattr__",
965                (
966                    variables.ConstantVariable.create("_training_state"),
967                    variables.EnumVariable(value),
968                ),
969                {},
970            )
971            self.param_group_var.value._training_state = value
972
973    def module_name(self):
974        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"
975
976    def fn_name(self):
977        return "use_training_state"
978
979
980class StreamVariable(VariableTracker):
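    """represents a device stream object (e.g. torch.cuda.Stream) and its proxy in the FX graph"""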
981    def __init__(self, proxy, value, device, **kwargs) -> None:
982        if proxy is not None and "example_value" in proxy.node.meta:
983            assert proxy.node.meta["example_value"] == value
984        assert (
985            value.device.type == device.type
986        ), "stream device does not match the passed device"
987        super().__init__(**kwargs)
988        self.proxy = proxy
989        self.value = value
990        self.device = device
991
992    def call_method(
993        self,
994        tx,
995        name,
996        args: "List[VariableTracker]",
997        kwargs: "Dict[str, VariableTracker]",
998    ) -> "VariableTracker":
999        assert hasattr(self.value, name), f"no stream method found named {name}"
1000        assert name in [
1001            "wait_stream",
1002            "synchronize",
1003            "query",
1004            "record_event",
1005            "wait_event",
1006        ], f"unsupported stream method {name}"
1007
1008        from ..utils import proxy_args_kwargs
1009        from .builder import wrap_fx_proxy_cls
1010
1011        if name in ("wait_stream", "synchronize", "wait_event"):
1012            tx.output.create_proxy(
1013                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
1014            )
1015            return variables.ConstantVariable(None)
1016        elif name == "query":
1017            return wrap_fx_proxy_cls(
1018                target_cls=variables.ConstantVariable,
1019                tx=tx,
1020                proxy=tx.output.create_proxy(
1021                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
1022                ),
1023            )
1024        elif name == "record_event":
1025            return wrap_fx_proxy_cls(
1026                target_cls=EventVariable,
1027                tx=tx,
1028                proxy=tx.output.create_proxy(
1029                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
1030                ),
1031            )
1032        else:
1033            unimplemented(f"{self.device} stream method {name} unsupported")
1034
1035    def as_proxy(self):
1036        return self.proxy
1037
1038    def reconstruct(self, codegen):
1039        # If we got here, this stream is fully subsumed by the graph - this means it is
1040        # not an input or global
1041        assert not self.source
1042        # Given that, reconstruction for other such structures, like lists and dicts, is
1043        # fine and sound according to dynamo's principles of treating collections. However,
1044        # streams are special in that we want to preserve the identity of the stream exactly as it is in the graph.
1045        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do
1046        # not yet have a plan for how to handle the case where the stream is used as an input or an output. Pending
1047        # that design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
1048        prefix = f"_stream_{self.device}"
1049        name = codegen.tx.output.install_global_by_id(prefix, self.value)
1050        codegen.append_output(codegen.create_load_global(name, add=True))
1051
1052
1053class EventVariable(VariableTracker):
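    """represents a device event (e.g. torch.cuda.Event) and its proxy in the FX graph"""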
1054    def __init__(self, proxy, value, **kwargs) -> None:
1055        if proxy is not None and "example_value" in proxy.node.meta:
1056            assert proxy.node.meta["example_value"] == value
1057        super().__init__(**kwargs)
1058        self.proxy = proxy
1059        self.value = value
1060
1061    def call_method(
1062        self,
1063        tx,
1064        name,
1065        args: "List[VariableTracker]",
1066        kwargs: "Dict[str, VariableTracker]",
1067    ) -> "VariableTracker":
1068        from ..utils import proxy_args_kwargs
1069        from .builder import wrap_fx_proxy_cls
1070
1071        if name in ("wait", "record", "synchronize"):
1072            tx.output.create_proxy(
1073                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
1074            )
1075            return variables.ConstantVariable(None)
1076        elif name == "query":
1077            return wrap_fx_proxy_cls(
1078                target_cls=variables.ConstantVariable,
1079                tx=tx,
1080                proxy=tx.output.create_proxy(
1081                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
1082                ),
1083            )
1084        else:
1085            unimplemented(f"event method {name} unsupported")
1086
1087    def as_proxy(self):
1088        return self.proxy
1089
1090    def reconstruct(self, codegen):
1091        # If we got here, this event is fully subsumed by the graph - this means it is
1092        # not an input or global
1093        assert not self.source
1094        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
1095        prefix = "_event"
1096        name = codegen.tx.output.install_global_by_id(prefix, self.value)
1097        codegen.append_output(codegen.create_load_global(name, add=True))
1098
1099
1100class WithExitFunctionVariable(VariableTracker):
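    """represents the bound __exit__ of a context manager entered during tracing; calling it invokes ctx.exit()"""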
1101    _nonvar_fields = {
1102        "target",
1103        *VariableTracker._nonvar_fields,
1104    }
1105
1106    def __init__(
1107        self,
1108        ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
1109        target,
1110        **kwargs,
1111    ) -> None:
1112        super().__init__(**kwargs)
1113        assert isinstance(
1114            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
1115        )
1116        self.ctx = ctx
1117        self.target = target
1118
1119    def call_function(
1120        self,
1121        tx: "InstructionTranslator",
1122        args: "List[VariableTracker]",
1123        kwargs: "Dict[str, VariableTracker]",
1124    ) -> "VariableTracker":
1125        assert not kwargs
1126        return self.ctx.exit(tx, *args)
1127
1128    def reconstruct(self, codegen):
1129        # Note here we reconstruct the context manager rather than the
1130        # exit function.  The handler generated by BlockStackEntry
1131        # will re-enter the context in the resume function.
1132        self.ctx.reconstruct_type(codegen)
1133        if codegen.tx.output.partial_convert:
1134            if sys.version_info >= (3, 11):
1135                codegen.append_output(create_instruction("PUSH_NULL"))
1136                if sys.version_info < (3, 13):
1137                    codegen.append_output(create_instruction("SWAP", arg=2))
1138            codegen.extend_output(
1139                [codegen.create_load_const(val) for val in self.ctx.target_values]
1140            )
1141            codegen.extend_output(
1142                create_call_function(len(self.ctx.target_values), False)
1143            )
1144            codegen.append_output(create_setup_with(self.target))
1145            codegen.append_output(create_instruction("POP_TOP"))
1146