xref: /aosp_15_r20/external/pytorch/torch/_dynamo/tensor_version_op.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch._prims import _make_prim, RETURN_TYPE
4from torch._subclasses import FakeTensorMode
5from torch._subclasses.functional_tensor import FunctionalTensorMode
6
7
# Prim that exposes a tensor's version counter as a traceable op.
# The schema takes a single Tensor and returns a SymInt.
_tensor_version = _make_prim(
    schema="_tensor_version(Tensor self) -> SymInt",
    return_type=RETURN_TYPE.NEW,
    # Both the meta and the ATen implementations forward to the same
    # `aten._version` overload.
    meta=torch.ops.aten._version.default,
    impl_aten=torch.ops.aten._version.default,
    doc="Tracable unbacked SymInt version of torch.Tensor._version",
)
15
16
@_tensor_version.py_impl(FakeTensorMode)
def _tensor_version_fake(fake_mode, self_tensor):
    """
    FakeTensorMode implementation of the `_tensor_version` prim.

    Instead of reading a concrete `._version`, this asks the fake mode's shape
    environment for a fresh unbacked SymInt. That way the initial dynamo capture
    of _tensor_version + _unsafe_set_version_counter does not have to specialize
    on the `._version` of the tensors that are inputs to the graph.
    """
    # `self_tensor` is not consulted: the result is symbolic by construction.
    shape_env = fake_mode.shape_env
    return shape_env.create_unbacked_symint()
25
26
# Prim that overwrites a tensor's version counter as a traceable op.
# The schema takes a Tensor and a SymInt version and returns nothing.
_unsafe_set_version_counter = _make_prim(
    schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()",
    return_type=RETURN_TYPE.NEW,
    # The op has no outputs, so the meta function simply returns None.
    meta=lambda self, version: None,
    impl_aten=torch._C._autograd._unsafe_set_version_counter,
    doc="Tracable+SymInt version of torch._C._autograd._unsafe_set_version_counter",
)
# Register the prim with torch.fx as side-effectful. The op returns nothing,
# so presumably this is what keeps FX from treating calls to it as dead code
# -- confirm against torch.fx.node.has_side_effect.
torch.fx.node.has_side_effect(_unsafe_set_version_counter)
35
36
37"""
38When we functionalize _tensor_version + _unsafe_set_version_counter,
39the ops disappear from the traced graph.  We run them eagerly on the
40fake tensors used for tracing, in order to get past asserts that would
41fail in autograd.
42
43Why is this ok?
441) Versions on functional tensors don't make any sense since you can't mutate a functional tensor.
2) The whole point of version munging is to trick autograd into doing what we want, and after
   AOTAutograd there is no longer any need for these ops.
47
48Note this is similar to how no_grad is handled.
49"""
50
51
@_tensor_version.py_impl(FunctionalTensorMode)
def _tensor_version_functional(mode, self):
    """
    FunctionalTensorMode implementation of the `_tensor_version` prim.

    Reads the `_version` attribute directly off `self` and returns it;
    `mode` is accepted but not used.
    """
    current_version = self._version
    return current_version
55
56
@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
def _unsafe_set_version_counter_functional(ctx, self, version):
    """
    FunctionalTensorMode implementation of the `_unsafe_set_version_counter` prim.

    Forwards `self` and `version` straight to the eager
    `torch._C._autograd._unsafe_set_version_counter` binding. `ctx` is accepted
    but not used, and nothing is returned.
    """
    set_version = torch._C._autograd._unsafe_set_version_counter
    set_version(self, version)
60