# mypy: allow-untyped-defs
import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode


class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """
    Detect grad state ops while exporting the graph and fail the process by
    raising an error, to avoid unexpected behavior. These grad mode ops include:
    `torch.no_grad`
    `torch.enable_grad`
    `torch.set_grad_enabled`

    Export with pre-dispatch mode is exempted.
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unsupported_grad_mode_ops = [
            torch._C._set_grad_enabled,
        ]
        # The check is only enforced while tracing, confirmed by an active PROXY
        # torch dispatch mode. This allows autograd ops to run normally outside
        # of tracing.
        current_state = torch._C.is_grad_enabled()
        if func in unsupported_grad_mode_ops:
            assert len(args) == 1
            changed_state = args[0]
            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
            # Check that this is not pre_dispatch mode. Autograd ops are allowed
            # in pre_dispatch mode, e.g. `torch.no_grad`.
            if (
                mode
                and isinstance(mode, ProxyTorchDispatchMode)
                and not mode.pre_dispatch
                and changed_state != current_state
            ):
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user's intention soundly. You can fix this by "
                    "adding a torch.no_grad() context around the export call."
                )
        return func(*args, **kwargs)
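

# Usage sketch (illustrative only). torch.export enables this safeguard
# internally while tracing; the explicit wrapping and the `Flip` module below
# are hypothetical and only demonstrate when the RuntimeError fires. The sketch
# assumes make_fx's default (non-pre-dispatch) proxy tracing, under which
# `torch._C._set_grad_enabled` is intercepted by this TorchFunctionMode:
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     class Flip(torch.nn.Module):
#         def forward(self, x):
#             # Toggles the global grad state mid-trace, which the safeguard
#             # rejects because the state change is not captured in the graph.
#             with torch.no_grad():
#                 return x * 2
#
#     with AutogradStateOpsFailSafeguard():
#         make_fx(Flip())(torch.randn(3))  # raises RuntimeError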