xref: /aosp_15_r20/external/pytorch/torch/_subclasses/_fake_tensor_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
5
6import torch
7from torch import SymInt
8from torch.fx.experimental.sym_node import SymNode
9from torch.types import py_sym_types, PySymType
10from torch.utils._backport_slots import dataclass_slots
11
12
13if TYPE_CHECKING:
14    import sympy
15
16    from torch.fx.experimental.symbolic_shapes import ShapeEnv
17
18    from .fake_tensor import _DispatchCacheKey, _MetadataIntLike
19
20
21@dataclass_slots
22@dataclass(frozen=True)
23class _DeconstructedSymNode:
24    """
25    Represents a SymNode without the associated ShapeEnv
26    """
27
28    # n.b. keep the same protocol as SymNode
29    _expr: sympy.Expr
30    pytype: type
31    _hint: Optional[Union[int, float, bool]]
32    constant: Optional[Union[int, float, bool]]
33    fx_node: torch.fx.Node
34
35    @staticmethod
36    def from_node(node: SymNode) -> _DeconstructedSymNode:
37        return _DeconstructedSymNode(
38            node._expr, node.pytype, node._hint, node.constant, node.fx_node
39        )
40
41    def extract(self, shape_env: ShapeEnv) -> SymNode:
42        return SymNode(
43            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
44        )
45
46    def __str__(self) -> str:
47        return str(self._expr)
48
49    def __repr__(self) -> str:
50        return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"
51
52    def __eq__(self, other: object) -> bool:
53        raise NotImplementedError
54
55    def __hash__(self) -> int:
56        raise NotImplementedError
57
58    # _value_eq to match SymNode
59    def _value_eq(self, other: object) -> bool:
60        if isinstance(other, (SymNode, _DeconstructedSymNode)):
61            return (
62                self._expr == other._expr
63                and self.pytype == other.pytype
64                and self._hint == other._hint
65                and self.constant == other.constant
66                and self.fx_node == other.fx_node
67            )
68        else:
69            return False
70
71    # _value_hash to match SymNode
72    def _value_hash(self) -> int:
73        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
74
75
76@dataclass_slots
77@dataclass(frozen=True)
78class _DeconstructedSymType:
79    """
80    Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
81    """
82
83    ty: Type[PySymType]
84    node: _DeconstructedSymNode
85
86    @staticmethod
87    def from_sym_type(value: PySymType) -> _DeconstructedSymType:
88        return _DeconstructedSymType(type(value), value.node)
89
90    def extract(self, shape_env: ShapeEnv) -> PySymType:
91        return self.ty(self.node.extract(shape_env))
92
93    def __str__(self) -> str:
94        return f"{self.ty}({self.node})"
95
96    def __repr__(self) -> str:
97        return f"_DeconstructedSymType({self.ty}, {self.node!r})"
98
99    def __eq__(self, other: object) -> bool:
100        return NotImplemented
101
102    def __hash__(self) -> int:
103        return NotImplemented
104
105
106@dataclass_slots
107@dataclass(frozen=True)
108class _InputBackref:
109    value: int
110
111
112@dataclass_slots
113@dataclass
114class _PySymInputStub:
115    """
116    Represents a SymInt in the cached key. Needed because SymInt doesn't
117    support __eq__ or __hash__ directly.
118    """
119
120    # value can be:
121    #   PySymType: This is the 'normal' SymInt value, wrapped so we can use
122    #              hash/eq as value hash/eq (normally SymInt does object
123    #              hash/eq).
124    #   _DeconstructedSymType: This is used when storing the _PySymInputStub in
125    #                          the cache to avoid cyclic ShapeEnv references.
126    #   _InputBackref: This is a back-reference to a previous _PySymInputStub in
127    #                  the key.
128    value: Union[PySymType, _DeconstructedSymType, _InputBackref]
129
130    def __init__(
131        self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
132    ) -> None:
133        # For inputs (values in the `key`) we need to keep the PySymType intact
134        # - this way if we need to reuse it as an output we can properly copy
135        # the original value.
136        self.value = value
137
138    def strip_shape_env(self) -> None:
139        if isinstance(self.value, py_sym_types):
140            self.value = _DeconstructedSymType.from_sym_type(self.value)
141
142    def extract(self, shape_env: ShapeEnv) -> PySymType:
143        if isinstance(self.value, _DeconstructedSymType):
144            return self.value.extract(shape_env)
145        else:
146            # We should never see an _InputBackref here - anyone extracting a
147            # value should be pulling from the original entry (the one this
148            # backref points at).
149            assert not isinstance(self.value, _InputBackref)
150            return self.value
151
152    def __str__(self) -> str:
153        return str(self.value)
154
155    def __repr__(self) -> str:
156        return f"_PySymInputStub({self.value!r})"
157
158    def __eq__(self, other: object) -> bool:
159        if not isinstance(other, _PySymInputStub):
160            return False
161        elif isinstance(self.value, _InputBackref) or isinstance(
162            other.value, _InputBackref
163        ):
164            return self.value == other.value
165        else:
166            return self.value.node._value_eq(other.value.node)
167
168    def __hash__(self) -> int:
169        if isinstance(self.value, _InputBackref):
170            return hash(self.value)
171        else:
172            return self.value.node._value_hash()
173
174
175@dataclass_slots
176@dataclass
177class _SymIntOutputStub:
178    """
179    Represents a SymInt in the cached output.
180    """
181
182    # This is either an `int` which represents the index in the key to copy the
183    # SymNode from or it's the deconstructed SymNode itself.
184    value: Union[int, _DeconstructedSymNode]
185
186    def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
187        if key_path is None:
188            self.value = _DeconstructedSymNode.from_node(value.node)
189        else:
190            self.value = key_path
191
192    def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
193        if isinstance(self.value, _DeconstructedSymNode):
194            return SymInt(self.value.extract(shape_env))
195        else:
196            src = key.key[self.value]
197            assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
198            return src.value
199
200    def __repr__(self) -> str:
201        return f"_SymIntOutputStub({self.value!r})"
202
203    def __eq__(self, other: object) -> bool:
204        raise NotImplementedError
205
206    def __hash__(self) -> int:
207        raise NotImplementedError
208
209
210@dataclass_slots
211@dataclass
212class _CacheKeyState:
213    """
214    State used while building our cache key.
215    """
216
217    # We track the SymNodes so when we get the output we can see if it exactly
218    # matches one of the inputs so we can uncache it properly.
219    sym_node_lookup: Dict[int, int]  # id(SymNode) -> index
220
221    # There are cases where we're asked to perform an op when we have no
222    # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
223    # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
224    # here.
225    shape_env: Optional[ShapeEnv]
226
227    def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
228        self.sym_node_lookup = {}
229        self.shape_env = shape_env
230
231    def cache_on_shape_env(self) -> bool:
232        """
233        Returns true if the CacheKey needs to be cached on the ShapeEnv
234        rather than the global cache.
235
236        If our inputs contain a SymNode then we can't cache this operation on
237        the global cache because the cached output will implicitly depend on
238        guard values which might not be true on some other ShapeEnv. So unless
239        we're also going to cache the guards we need to cache this operation on
240        the ShapeEnv instead of globally.
241        """
242        return bool(self.sym_node_lookup)
243
244    def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
245        node_id = id(arg.node)
246        if node_id in self.sym_node_lookup:
247            result.append(_InputBackref(self.sym_node_lookup[node_id]))
248        else:
249            self.sym_node_lookup[node_id] = len(result)
250            if self.shape_env is None:
251                self.shape_env = arg.node.shape_env
252            result.append(_PySymInputStub(arg))
253
254    def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
255        if isinstance(arg, SymInt):
256            return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
257        else:
258            return arg
259