1# mypy: allow-untyped-defs 2import torch 3from torch._prims import _make_prim, RETURN_TYPE 4from torch._subclasses import FakeTensorMode 5from torch._subclasses.functional_tensor import FunctionalTensorMode 6 7 8_tensor_version = _make_prim( 9 schema="_tensor_version(Tensor self) -> SymInt", 10 return_type=RETURN_TYPE.NEW, 11 meta=torch.ops.aten._version.default, 12 impl_aten=torch.ops.aten._version.default, 13 doc="Tracable unbacked SymInt version of torch.Tensor._version", 14) 15 16 17@_tensor_version.py_impl(FakeTensorMode) 18def _tensor_version_fake(fake_mode, self_tensor): 19 """ 20 The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the 21 `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` 22 of input tensors to the graph. 23 """ 24 return fake_mode.shape_env.create_unbacked_symint() 25 26 27_unsafe_set_version_counter = _make_prim( 28 schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()", 29 return_type=RETURN_TYPE.NEW, 30 meta=lambda self, version: None, 31 impl_aten=torch._C._autograd._unsafe_set_version_counter, 32 doc="Tracable+SymInt version of torch._C._autograd._unsafe_set_version_counter", 33) 34torch.fx.node.has_side_effect(_unsafe_set_version_counter) 35 36 37""" 38When we functionalize _tensor_version + _unsafe_set_version_counter, 39the ops disappear from the traced graph. We run them eagerly on the 40fake tensors used for tracing, in order to get past asserts that would 41fail in autograd. 42 43Why is this ok? 441) Versions on functional tensors don't make any sense since you can't mutate a functional tensor. 452) The whole point of version munging is to trick autograd into doing what we want, and after 46 AotAtuograd there is no longer any need for these ops. 47 48Note this is similar to how no_grad is handled. 49""" 50 51 52@_tensor_version.py_impl(FunctionalTensorMode) 53def _tensor_version_functional(mode, self): 54 return self._version 55 56 57@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) 58def _unsafe_set_version_counter_functional(ctx, self, version): 59 torch._C._autograd._unsafe_set_version_counter(self, version) 60