1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import weakref 5from weakref import ref 6from _weakrefset import _IterationGuard # type: ignore[attr-defined] 7from collections.abc import MutableMapping, Mapping 8from torch import Tensor 9import collections.abc as _collections_abc 10 11 12WeakRef = ref 13 14 15__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary'] 16 17 18# This file defines a variant of WeakKeyDictionary that overrides the hashing 19# behavior of the key to use object identity, rather than the builtin 20# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their 21# __eq__ implementation return a Tensor (elementwise equality), which means 22# you can't use them directly with the WeakKeyDictionary in standard library. 23# 24# Our implementation strategy is to create a wrapper weak key object, which we 25# use as a key in a stock Python dictionary. This is similar to how weakref 26# implements WeakKeyDictionary, but instead of using weakref.ref as the 27# wrapper, we use a custom wrapper that has different __eq__ and __hash__ 28# behavior. Note that we subsequently store this weak key directly in an 29# ORDINARY dictionary, since the newly constructed WeakIdKey's only use would 30# be a dictionary so it would have no strong references. Ensuring that 31# only live WeakIdKeys are in the map is handled by putting finalizers on the 32# original key object. 33 34 35# It is simpler to implement this with composition, but if we want to 36# directly reuse the callback mechanism on weakref, we need the weakref 37# and the key to be exactly the same object. Reusing the callback mechanism 38# minimizes the divergence between our implementation and Lib/weakref.py 39# 40# NB: Prefer using this when working with weakrefs of Tensors; e.g., do 41# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of 42# easy to get wrong cases transparently for you. 
class WeakIdRef(weakref.ref):
    """Weak reference whose hashing and equality are based on object identity.

    ``hash(WeakIdRef(obj)) == id(obj)``. Two references compare equal when
    both referents are alive and are the very same object (``is``). Otherwise
    they compare equal only when they are the same reference object. The
    referent's own ``__eq__``/``__hash__`` are never consulted.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method. The cached value also keeps a dead reference hashable, so
        # it can still be found and deleted from a dictionary.
        self._id = id(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection: any referent that
        # exposes `_fix_weakref` has it invoked on each successful dereference.
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # Only weak references can be dereferenced. For anything else, let
        # Python fall back to its default comparison instead of raising
        # TypeError from `other()`.
        if not isinstance(other, weakref.ref):
            return NotImplemented
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match. This can lead to an ABA problem if you have:
        #
        # a1 = A()
        # w1 = WeakIdRef(a1)
        # del a1
        # a2 = A() # suppose it gets the same ID as a1
        # w2 = WeakIdRef(a2)
        # print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        return self is other

    def __ne__(self, other):
        # Must be defined explicitly. Otherwise weakref.ref's own __ne__ is
        # inherited, which compares the *referents* with `!=`. That can
        # disagree with __eq__ above, and for Tensors it returns a Tensor.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

# This is the same as WeakIdRef, but the cached hash and the equality check
# use hash() of the referent rather than id(). It is equivalent to the class
# above except for classes whose hash is not their id.
class _WeakHashRef(weakref.ref):
    """Weak reference that hashes and compares by the referent's ``hash()``.

    Two live references compare equal whenever their referents hash
    equally. Otherwise they compare equal only when they are the same
    reference object.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Eagerly cache the referent's hash. It stays available after the
        # referent dies, so the dead reference can still be located (and
        # deleted) in a dictionary. Unhashable keys raise TypeError here.
        self._id = hash(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection: any referent that
        # exposes `_fix_weakref` has it invoked on each successful dereference.
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # Only weak references can be dereferenced. For anything else, let
        # Python fall back to its default comparison instead of raising
        # TypeError from `other()`.
        if not isinstance(other, weakref.ref):
            return NotImplemented
        # Use hash equality to determine ref equality.
        # ScriptObject implements __hash__ to return the wrapped IValue's id, so
        # this is equivalent to doing an identity comparison.
        a = self()
        b = other()
        if a is not None and b is not None:
            return hash(a) == hash(b)
        return self is other

    def __ne__(self, other):
        # Must be defined explicitly. Otherwise weakref.ref's own __ne__ is
        # inherited, which compares the *referents* with `!=` and can
        # disagree with __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

# This is directly adapted from cpython/Lib/weakref.py
class WeakIdKeyDictionary(MutableMapping):
    """Mapping with weakly held keys that are matched via ``ref_type``.

    Each key is wrapped as ``ref_type(key)`` (WeakIdRef by default, i.e.
    identity semantics), so the keys' own ``__eq__``/``__hash__`` are not used
    by the default wrapper. An entry is dropped once its key object is
    garbage collected.

    Args:
        dict: optional initial contents. Either a mapping, or anything the
            builtin ``dict()`` accepts (e.g. an iterable of pairs).
        ref_type: weak-reference wrapper for keys. It is called both as
            ``ref_type(key)`` and ``ref_type(key, callback)``. Copies made by
            ``copy``, ``deepcopy`` and ``|`` keep the same ``ref_type``.
    """

    def __init__(self, dict=None, ref_type=WeakIdRef):  # CHANGED
        self.data = {}

        self.ref_type = ref_type  # CHANGED

        def remove(k, selfref=ref(self)):
            # Callback attached to every stored key reference. It holds the
            # dictionary only weakly, so it does not keep the dictionary alive.
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Deleting while an iterator is live would break it;
                    # queue the key for _commit_removals instead.
                    self._pending_removals.append(k)
                else:
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        self._iterating = set()
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        pop = self._pending_removals.pop
        d = self.data
        while True:
            try:
                key = pop()
            except IndexError:
                return

            try:
                del d[key]
            except KeyError:
                pass

    def _scrub_removals(self):
        # Forget pending removals whose keys are already gone from self.data
        # (e.g. deleted explicitly), so __len__ does not subtract them.
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[self.ref_type(key)]  # CHANGED

    def __getitem__(self, key):
        return self.data[self.ref_type(key)]  # CHANGED

    def __len__(self):
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return f"<{self.__class__.__name__} at {id(self):#x}>"

    def __setitem__(self, key, value):
        self.data[self.ref_type(key, self._remove)] = value  # CHANGED

    def copy(self):
        """Return a shallow copy of the live entries, keeping ``ref_type``."""
        # Pass ref_type through; otherwise a copy of e.g. a _WeakHashRef
        # dictionary would silently switch to identity-keyed WeakIdRef.
        new = WeakIdKeyDictionary(ref_type=self.ref_type)
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Copy with deep-copied values; keys are shared, ``ref_type`` is kept."""
        from copy import deepcopy
        new = self.__class__()
        # Set after construction so subclasses with other __init__ signatures
        # keep working.
        new.ref_type = self.ref_type
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(self.ref_type(key), default)  # CHANGED

    def __contains__(self, key):
        try:
            wr = self.ref_type(key)  # CHANGED
        except TypeError:
            # Keys that cannot be weakly referenced (or hashed) are never
            # stored, so they are simply "not in" the mapping.
            return False
        return wr in self.data

    def items(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used. This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return list(self.data)

    def popitem(self):
        self._dirty_len = True
        # Skip entries whose keys died but have not been removed yet.
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(self.ref_type(key), *args)  # CHANGED

    def setdefault(self, key, default=None):
        return self.data.setdefault(self.ref_type(key, self._remove), default)  # CHANGED

    def update(self, dict=None, **kwargs):
        d = self.data
        if dict is not None:
            if not hasattr(dict, "items"):
                dict = type({})(dict)
            for key, value in dict.items():
                d[self.ref_type(key, self._remove)] = value  # CHANGED
        if len(kwargs):
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.copy()
            c.update(other)
            return c
        return NotImplemented

    def __ror__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.__class__()
            # The result uses this dictionary's key semantics.
            c.ref_type = self.ref_type
            c.update(other)
            c.update(self)
            return c
        return NotImplemented

    # Default Mapping equality will tests keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}

# Convenience alias
WeakTensorKeyDictionary = WeakIdKeyDictionary


class TensorWeakRef:
    """Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""

    ref: WeakRef[Tensor]

    def __init__(self, tensor: Tensor):
        assert isinstance(tensor, Tensor)
        self.ref = weakref.ref(tensor)

    def __call__(self):
        """Return the referenced Tensor, or None if it has been collected."""
        out = self.ref()
        if out is None:
            return out
        assert isinstance(out, Tensor)
        # TODO, add _fix_weakref type binding
        out._fix_weakref()  # type: ignore[attr-defined]
        return out