# mypy: ignore-errors
import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard

from .. import variables
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


@dataclasses.dataclass
class ContextMangerState:
    """
    Mutating `self` in VariableTracker is not allowed because we copy
    them. This is a mutable container pointed to by context managers
    that won't get copied, so it is safe to mutate.
    """

    cleanup_fn: Optional[Callable] = None
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None

    def cleanup_assert(self):
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()


class ContextWrappingVariable(VariableTracker):
    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, target_values, initial_values=None, *, state=None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextMangerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        codegen(
            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
        )

    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: self.reconstruct_type(codegen))
        target_values = self.target_values
        if not target_values:
            target_values = ()
        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
        codegen.extend_output(create_call_function(len(target_values), False))

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)


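# Illustrative sketch only (hypothetical user code, not part of this module):
# the class below handles arbitrary user-defined context managers by inlining
# their __enter__/__exit__ methods into the trace, e.g.
#
#     class MyCtx:
#         def __enter__(self):
#             return self
#
#         def __exit__(self, exc_type, exc_val, exc_tb):
#             return False
#
#     @torch.compile
#     def f(x):
#         with MyCtx():       # __enter__/__exit__ are traced like ordinary user
#             return x + 1    # methods; if either is unsupported, Dynamo graph-breaks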
class GenericContextWrappingVariable(UserDefinedObjectVariable):
    # Some methods in ContextWrappingVariable assume the arguments are
    # Python constants, which might not always be the case here.
    def __init__(self, cm_obj, **kwargs) -> None:
        assert cm_obj is not None
        super().__init__(
            value=cm_obj,
            value_type=cm_obj.__class__,
            **kwargs,
        )
        self.cm_obj = cm_obj

    def module_name(self):
        return self.cm_obj.__module__

    def fn_name(self):
        return type(self.cm_obj).__name__

    def enter(self, tx):
        source = None if self.source is None else AttrSource(self.source, "__enter__")
        try:
            return variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                self,
                source=source,
            ).call_function(tx, [], {})
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function",
                from_exc=e,
            )

    def exit(self, tx: "InstructionTranslator", *args):
        source = None if self.source is None else AttrSource(self.source, "__exit__")
        try:
            x = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                self,
                source=source,
            ).call_function(
                tx,
                [
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                ],
                {},
            )
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function",
                from_exc=e,
            )

        tx.generic_context_manager_depth -= 1
        return x


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch._C._functorch.set_inplace_requires_grad_allowed()"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return GradInplaceRequiresGradCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [enabled] = self.target_values
        self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
        torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
                self.prev_state
            ),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (enabled,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


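# Illustrative sketch only (hypothetical user code): the functorch nesting
# managers below (jvp/grad/vmap) mirror an eager-side state change into the FX
# graph and install a FUNCTORCH_STACK_MATCH guard so a compiled graph is only
# reused when the functorch interpreter stack looks the same, e.g.
#
#     @torch.compile
#     def f(x, t):
#         return torch.func.jvp(torch.sin, (x,), (t,))  # jvp nesting is entered at
#                                                       # trace time and replayed in
#                                                       # the generated graph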
class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.jvp increment/decrement nesting"""

    # A guard is needed as the jvp level is baked into the torch FX graph
    # generated. This is fine if jvp is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a jvp
    # call from eager that calls the compiled function, as the jvp levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = JvpIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
        self.set_cleanup_hook(
            tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._jvp_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(jvp_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class SetFwdGradEnabledContextManager(ContextWrappingVariable):
    """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return SetFwdGradEnabledContextManager(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [mode] = self.target_values
        self.prev_state = torch._C._is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (mode,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


class DualLevelContextManager(ContextWrappingVariable):
    """Represents torch.autograd.forward_ad.dual_level ctx manager"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        return DualLevelContextManager(
            target_values=None,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        self.new_level = torch.autograd.forward_ad.enter_dual_level()
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._enter_dual_level,
            (),
            {},
        )
        return variables.ConstantVariable.create(self.new_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._exit_dual_level,
            (self.new_level,),
            {},
        )
        return variables.ConstantVariable.create(None)


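# Illustrative sketch only (hypothetical user code): SetFwdGradEnabledContextManager
# and DualLevelContextManager above cover forward-mode AD state, e.g.
#
#     @torch.compile
#     def f(x, t):
#         with torch.autograd.forward_ad.dual_level():  # DualLevelContextManager
#             y = torch.autograd.forward_ad.make_dual(x, t)
#             return torch.autograd.forward_ad.unpack_dual(torch.sin(y)).tangent

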
class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.grad increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph
    # generated. This is fine if grad is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a grad
    # call from eager that calls the compiled function, as the grad levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = GradIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        grad_level = torch._C._functorch._grad_increment_nesting()
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._grad_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(grad_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
    """Delay a call to warnings.catch_warnings"""

    @staticmethod
    def create(tx: "InstructionTranslator", catch_warnings_args):
        return CatchWarningsCtxManagerVariable(
            catch_warnings_args=catch_warnings_args,
            target_values=None,
            initial_values=None,
        )

    def __init__(self, catch_warnings_args, **kwargs) -> None:
        assert isinstance(catch_warnings_args, dict), catch_warnings_args
        super().__init__(**kwargs)
        self.catch_warnings_args = catch_warnings_args

    def enter(self, tx):
        kwargs = {
            k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
        }
        ctx_val = warnings.catch_warnings(**kwargs)
        self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
        return variables.ConstantVariable.create(ctx_val.__enter__())

    def reconstruct(self, cg):
        cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings"))
        cg.foreach(self.catch_warnings_args.values())
        keys = tuple(self.catch_warnings_args.keys())
        cg.extend_output(cg.create_call_function_kw(len(keys), keys, False))


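# Illustrative sketch only (hypothetical user code): CatchWarningsCtxManagerVariable
# above re-creates warnings.catch_warnings at trace time, and its keyword arguments
# must be Python constants so the call can be reconstructed after a graph break, e.g.
#
#     @torch.compile
#     def f(x):
#         with warnings.catch_warnings():
#             return x + 1

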
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.vmap increment/decrement nesting"""

    # A guard is needed as the vmap level is baked into the torch FX graph
    # generated. This is fine if vmap is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a vmap
    # call from eager that calls the compiled function, as the vmap levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        var = VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(vmap_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            var._call_func(tx, var.target_values)
        return var

    def __init__(
        self, target_values, initial_values=None, initialized=True, **kwargs
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        # Coalesce grad mode mutations
        if torch.is_grad_enabled() != value:
            tx.output.create_node(
                "call_function", torch._C._set_grad_enabled, (value,), {}
            )
            torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"


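# Illustrative sketch only (hypothetical user code): GradModeVariable above coalesces
# redundant grad-mode toggles - _call_func only records a _set_grad_enabled node when
# the grad mode actually changes, e.g.
#
#     @torch.compile
#     def f(x):
#         with torch.no_grad():
#             with torch.no_grad():  # no second _set_grad_enabled node is emitted
#                 return x + 1

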
class InferenceModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = InferenceModeVariable(
            [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"


class CUDADeviceVariable(ContextWrappingVariable):
    """represents torch.cuda.device"""

    @staticmethod
    def create(tx: "InstructionTranslator", device, **kwargs):
        var = CUDADeviceVariable(
            target_values=[torch.cuda._get_device_index(device, optional=True)],
            initial_values=None,
            **kwargs,
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.cuda._maybe_exchange_device,
            (self.state.proxy,),
            {},
        )
        return variables.ConstantVariable.create(False)

    def enter(self, tx):
        prev_idx = torch.cuda._exchange_device(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.cuda._exchange_device,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch.cuda"

    def fn_name(self):
        return "device"


class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        var._call_func(tx, [False])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        tx.output.set_torch_function_state(values[0])


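# Illustrative sketch only (hypothetical user code): the class below applies
# torch.use_deterministic_algorithms at trace time, records the equivalent call in
# the FX graph, and installs a DETERMINISTIC_ALGORITHMS guard, e.g.
#
#     @torch.compile
#     def f(x):
#         torch.use_deterministic_algorithms(True)
#         return x + 1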
class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"


class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hooks()."""

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[
                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
            ],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        if value is not None:
            # Disable `saved_tensors_hooks` with message (`value`)
            # OR
            # we are exiting this context and restoring the previous message.
            tx.output.create_node(
                "call_function",
                torch._C._autograd._saved_tensors_hooks_disable,
                (value,),
                {},
            )
            torch._C._autograd._saved_tensors_hooks_disable(value)
        else:
            # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"


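# Illustrative sketch only (hypothetical user code): the class below normalizes the
# legacy autocast entry points onto torch.amp, so e.g.
#
#     with torch.cuda.amp.autocast(dtype=torch.float16):
#         ...
#
# is traced as torch.amp._enter_autocast("cuda", torch.float16, True, None) with a
# matching torch.amp._exit_autocast on exit.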
class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # autocast signature, for reference:
        #   device_type: str,
        #   dtype: Optional[_dtype] = None,
        #   enabled: bool = True,
        #   cache_enabled: Optional[bool] = None
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )

    def enter(self, tx):
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"


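# Illustrative sketch only (hypothetical user code): the class below is used as a
# stand-in for context managers Dynamo does not model, which then act as no-ops in
# the compiled graph, e.g.
#
#     with torch.autograd.profiler.record_function("my_block"):  # treated like
#         ...                                                    # contextlib.nullcontext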
class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs) -> None:
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class StreamContextVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream_method = get_interface_for_device(
            target_value.device
        ).current_stream
        current_stream = wrap_fx_proxy_cls(
            StreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                current_stream_method,
                (None,),
                {},
            ),
        )
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            device=target_value.device,
            **kwargs,
        )

    def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.device = device
        self.set_stream = get_interface_for_device(self.device).set_stream
        self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id

    def enter(self, tx):
        # stream generated inside the traced function
        if self.target_values[0].as_proxy() is not None:
            tx.output.create_proxy(
                "call_function",
                self.set_stream,
                (self.target_values[0].as_proxy(),),
                {},
            )
        # stream passed from outside the traced function
        else:
            stream = self.target_values[0].value
            tx.output.create_proxy(
                "call_function",
                self.set_stream_id,
                (stream.stream_id, stream.device_index, stream.device_type),
                {},
            )
        self.set_stream(self.target_values[0].value)
        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))

    def exit(self, tx: "InstructionTranslator", *args):
        tx.output.create_proxy(
            "call_function",
            self.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        self.state.cleanup_assert()


class PreserveVersionContextVariable(ContextWrappingVariable):
    """
    Wraps torch.autograd._unsafe_preserve_version_counter
    """

    @staticmethod
    def constructor(tx):
        return variables.LambdaVariable(
            lambda tensor: PreserveVersionContextVariable(
                tensor,
                tensor.var_getattr(tx, "_version"),
            )
        )

    def __init__(self, tensor, prev_version, **kwargs) -> None:
        kwargs.setdefault("target_values", None)
        super().__init__(**kwargs)
        self.tensor = tensor
        self.prev_version = prev_version

    def enter(self, tx):
        pass

    def exit(self, tx: "InstructionTranslator", *args):
        from ..tensor_version_op import _unsafe_set_version_counter

        return variables.TorchInGraphFunctionVariable(
            _unsafe_set_version_counter
        ).call_function(tx, [self.tensor, self.prev_version], {})

    def reconstruct(self, codegen):
        unimplemented(
            "torch.autograd._unsafe_preserve_version_counter with graph break"
        )


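# Illustrative sketch only (hypothetical user code): StreamContextVariable above models
# stream-switching context managers such as torch.cuda.stream; the current stream is
# swapped both at trace time and via set_stream / _set_stream_by_id nodes in the graph,
# e.g.
#
#     s = torch.cuda.Stream()
#
#     @torch.compile
#     def f(x):
#         with torch.cuda.stream(s):  # stream passed in from outside the traced function
#             return x + 1

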
class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE)

    @staticmethod
    def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs):
        var = FSDPParamGroupUseTrainingStateVariable(
            param_group_var=param_group_var,
            target_values=[target_value],
            initial_values=[param_group_var.value._training_state],
            **kwargs,
        )
        return var

    def __init__(
        self, param_group_var, target_values, initial_values=None, **kwargs
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.param_group_var = param_group_var
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        if self.param_group_var.value._training_state != value:
            self.param_group_var.call_method(
                tx,
                "__setattr__",
                (
                    variables.ConstantVariable.create("_training_state"),
                    variables.EnumVariable(value),
                ),
                {},
            )
            self.param_group_var.value._training_state = value

    def module_name(self):
        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"

    def fn_name(self):
        return "use_training_state"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f"unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global.
        assert not self.source
        # For other such structures, like lists and dicts, reconstruction is fine and sound
        # according to dynamo principles of treating collectives. Streams are special, however,
        # in that we want to preserve the identity of the stream exactly as it appears in the
        # graph. Normally we would do this via codegen for the proxy mapping to an output, but
        # we cannot do that yet, as we do not have a plan for handling the case where the stream
        # is used as an input or an output. Pending that design, to unblock current work, we
        # lift the stream into a global and then codegen bytecode to load it from there.
        prefix = f"_stream_{self.device}"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))


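# Illustrative sketch only (hypothetical user code): StreamVariable above proxies the
# supported stream methods straight into the graph as call_method nodes, e.g.
#
#     @torch.compile
#     def f(x):
#         s = torch.cuda.Stream()
#         s.wait_stream(torch.cuda.current_stream())  # emitted as a call_method node
#         with torch.cuda.stream(s):
#             return x + 1

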
class EventVariable(VariableTracker):
    def __init__(self, proxy, value, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"event method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this event is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))


class WithExitFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "target",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self,
        ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
        target,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(
            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
        )
        self.ctx = ctx
        self.target = target

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        self.ctx.reconstruct_type(codegen)
        if codegen.tx.output.partial_convert:
            if sys.version_info >= (3, 11):
                codegen.append_output(create_instruction("PUSH_NULL"))
                if sys.version_info < (3, 13):
                    codegen.append_output(create_instruction("SWAP", arg=2))
            codegen.extend_output(
                [codegen.create_load_const(val) for val in self.ctx.target_values]
            )
            codegen.extend_output(
                create_call_function(len(self.ctx.target_values), False)
            )
            codegen.append_output(create_setup_with(self.target))
            codegen.append_output(create_instruction("POP_TOP"))