1from __future__ import annotations 2 3from dataclasses import dataclass 4from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union 5 6import torch 7from torch import SymInt 8from torch.fx.experimental.sym_node import SymNode 9from torch.types import py_sym_types, PySymType 10from torch.utils._backport_slots import dataclass_slots 11 12 13if TYPE_CHECKING: 14 import sympy 15 16 from torch.fx.experimental.symbolic_shapes import ShapeEnv 17 18 from .fake_tensor import _DispatchCacheKey, _MetadataIntLike 19 20 21@dataclass_slots 22@dataclass(frozen=True) 23class _DeconstructedSymNode: 24 """ 25 Represents a SymNode without the associated ShapeEnv 26 """ 27 28 # n.b. keep the same protocol as SymNode 29 _expr: sympy.Expr 30 pytype: type 31 _hint: Optional[Union[int, float, bool]] 32 constant: Optional[Union[int, float, bool]] 33 fx_node: torch.fx.Node 34 35 @staticmethod 36 def from_node(node: SymNode) -> _DeconstructedSymNode: 37 return _DeconstructedSymNode( 38 node._expr, node.pytype, node._hint, node.constant, node.fx_node 39 ) 40 41 def extract(self, shape_env: ShapeEnv) -> SymNode: 42 return SymNode( 43 self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node 44 ) 45 46 def __str__(self) -> str: 47 return str(self._expr) 48 49 def __repr__(self) -> str: 50 return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}" 51 52 def __eq__(self, other: object) -> bool: 53 raise NotImplementedError 54 55 def __hash__(self) -> int: 56 raise NotImplementedError 57 58 # _value_eq to match SymNode 59 def _value_eq(self, other: object) -> bool: 60 if isinstance(other, (SymNode, _DeconstructedSymNode)): 61 return ( 62 self._expr == other._expr 63 and self.pytype == other.pytype 64 and self._hint == other._hint 65 and self.constant == other.constant 66 and self.fx_node == other.fx_node 67 ) 68 else: 69 return False 70 71 # _value_hash to match SymNode 72 def _value_hash(self) -> int: 73 return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) 74 75 76@dataclass_slots 77@dataclass(frozen=True) 78class _DeconstructedSymType: 79 """ 80 Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv 81 """ 82 83 ty: Type[PySymType] 84 node: _DeconstructedSymNode 85 86 @staticmethod 87 def from_sym_type(value: PySymType) -> _DeconstructedSymType: 88 return _DeconstructedSymType(type(value), value.node) 89 90 def extract(self, shape_env: ShapeEnv) -> PySymType: 91 return self.ty(self.node.extract(shape_env)) 92 93 def __str__(self) -> str: 94 return f"{self.ty}({self.node})" 95 96 def __repr__(self) -> str: 97 return f"_DeconstructedSymType({self.ty}, {self.node!r})" 98 99 def __eq__(self, other: object) -> bool: 100 return NotImplemented 101 102 def __hash__(self) -> int: 103 return NotImplemented 104 105 106@dataclass_slots 107@dataclass(frozen=True) 108class _InputBackref: 109 value: int 110 111 112@dataclass_slots 113@dataclass 114class _PySymInputStub: 115 """ 116 Represents a SymInt in the cached key. Needed because SymInt doesn't 117 support __eq__ or __hash__ directly. 118 """ 119 120 # value can be: 121 # PySymType: This is the 'normal' SymInt value, wrapped so we can use 122 # hash/eq as value hash/eq (normally SymInt does object 123 # hash/eq). 124 # _DeconstructedSymType: This is used when storing the _PySymInputStub in 125 # the cache to avoid cyclic ShapeEnv references. 126 # _InputBackref: This is a back-reference to a previous _PySymInputStub in 127 # the key. 128 value: Union[PySymType, _DeconstructedSymType, _InputBackref] 129 130 def __init__( 131 self, value: Union[PySymType, _DeconstructedSymType, _InputBackref] 132 ) -> None: 133 # For inputs (values in the `key`) we need to keep the PySymType intact 134 # - this way if we need to reuse it as an output we can properly copy 135 # the original value. 136 self.value = value 137 138 def strip_shape_env(self) -> None: 139 if isinstance(self.value, py_sym_types): 140 self.value = _DeconstructedSymType.from_sym_type(self.value) 141 142 def extract(self, shape_env: ShapeEnv) -> PySymType: 143 if isinstance(self.value, _DeconstructedSymType): 144 return self.value.extract(shape_env) 145 else: 146 # We should never see an _InputBackref here - anyone extracting a 147 # value should be pulling from the original entry (the one this 148 # backref points at). 149 assert not isinstance(self.value, _InputBackref) 150 return self.value 151 152 def __str__(self) -> str: 153 return str(self.value) 154 155 def __repr__(self) -> str: 156 return f"_PySymInputStub({self.value!r})" 157 158 def __eq__(self, other: object) -> bool: 159 if not isinstance(other, _PySymInputStub): 160 return False 161 elif isinstance(self.value, _InputBackref) or isinstance( 162 other.value, _InputBackref 163 ): 164 return self.value == other.value 165 else: 166 return self.value.node._value_eq(other.value.node) 167 168 def __hash__(self) -> int: 169 if isinstance(self.value, _InputBackref): 170 return hash(self.value) 171 else: 172 return self.value.node._value_hash() 173 174 175@dataclass_slots 176@dataclass 177class _SymIntOutputStub: 178 """ 179 Represents a SymInt in the cached output. 180 """ 181 182 # This is either an `int` which represents the index in the key to copy the 183 # SymNode from or it's the deconstructed SymNode itself. 184 value: Union[int, _DeconstructedSymNode] 185 186 def __init__(self, value: SymInt, key_path: Optional[int]) -> None: 187 if key_path is None: 188 self.value = _DeconstructedSymNode.from_node(value.node) 189 else: 190 self.value = key_path 191 192 def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt: 193 if isinstance(self.value, _DeconstructedSymNode): 194 return SymInt(self.value.extract(shape_env)) 195 else: 196 src = key.key[self.value] 197 assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt) 198 return src.value 199 200 def __repr__(self) -> str: 201 return f"_SymIntOutputStub({self.value!r})" 202 203 def __eq__(self, other: object) -> bool: 204 raise NotImplementedError 205 206 def __hash__(self) -> int: 207 raise NotImplementedError 208 209 210@dataclass_slots 211@dataclass 212class _CacheKeyState: 213 """ 214 State used while building our cache key. 215 """ 216 217 # We track the SymNodes so when we get the output we can see if it exactly 218 # matches one of the inputs so we can uncache it properly. 219 sym_node_lookup: Dict[int, int] # id(SymNode) -> index 220 221 # There are cases where we're asked to perform an op when we have no 222 # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a 223 # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it 224 # here. 225 shape_env: Optional[ShapeEnv] 226 227 def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None: 228 self.sym_node_lookup = {} 229 self.shape_env = shape_env 230 231 def cache_on_shape_env(self) -> bool: 232 """ 233 Returns true if the CacheKey needs to be cached on the ShapeEnv 234 rather than the global cache. 235 236 If our inputs contain a SymNode then we can't cache this operation on 237 the global cache because the cached output will implicitly depend on 238 guard values which might not be true on some other ShapeEnv. So unless 239 we're also going to cache the guards we need to cache this operation on 240 the ShapeEnv instead of globally. 241 """ 242 return bool(self.sym_node_lookup) 243 244 def convert_sym_int(self, result: List[object], arg: SymInt) -> None: 245 node_id = id(arg.node) 246 if node_id in self.sym_node_lookup: 247 result.append(_InputBackref(self.sym_node_lookup[node_id])) 248 else: 249 self.sym_node_lookup[node_id] = len(result) 250 if self.shape_env is None: 251 self.shape_env = arg.node.shape_env 252 result.append(_PySymInputStub(arg)) 253 254 def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike: 255 if isinstance(arg, SymInt): 256 return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None)) 257 else: 258 return arg 259