xref: /aosp_15_r20/external/pytorch/torch/utils/weak.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import weakref
5from weakref import ref
6from _weakrefset import _IterationGuard  # type: ignore[attr-defined]
7from collections.abc import MutableMapping, Mapping
8from torch import Tensor
9import collections.abc as _collections_abc
10
11
# ``WeakRef`` is simply another name for ``weakref.ref``; it lets annotations
# in this module read as ``WeakRef[Tensor]``.
WeakRef = weakref.ref


__all__ = [
    'TensorWeakRef',
    'WeakIdRef',
    'WeakIdKeyDictionary',
    'WeakTensorKeyDictionary',
]
16
17
# This file defines a variant of WeakKeyDictionary that overrides the hashing
# behavior of the key to use object identity, rather than the builtin
# __eq__/__hash__ functions.  This is useful for Tensor weak keys, as their
# __eq__ implementation returns a Tensor (elementwise equality), which means
# you can't use them directly with the WeakKeyDictionary in the standard library.
#
# Our implementation strategy is to create a wrapper weak key object, which we
# use as a key in a stock Python dictionary.  This is similar to how weakref
# implements WeakKeyDictionary, but instead of using weakref.ref as the
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
# behavior.  Note that we store this weak key directly in an ORDINARY
# dictionary: the dictionary is the only place the newly constructed WeakIdRef
# is used, so nothing else holds a strong reference to it.  Ensuring that
# only live WeakIdRefs remain in the map is handled by a weakref callback that
# fires when the original key object dies and removes its entry.
33
34
# It is simpler to implement this with composition, but if we want to
# directly reuse the callback mechanism on weakref, we need the weakref
# and the key to be exactly the same object.  Reusing the callback mechanism
# minimizes the divergence between our implementation and Lib/weakref.py.
#
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
# easy-to-get-wrong cases transparently for you.
class WeakIdRef(weakref.ref):
    """Weak reference that hashes and compares by the referent's identity.

    ``hash(WeakIdRef(obj)) == id(obj)``.  Two refs compare equal when both
    referents are alive and are the same object, or when they are the very
    same ref object.  The referent's own ``__eq__``/``__hash__``/``__ne__``
    are never invoked, which makes this safe to use as a dictionary key for
    objects (such as Tensors) whose ``__eq__`` is elementwise.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method.
        self._id = id(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        """Return the referent, or ``None`` if it has been collected."""
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match.  This can lead to an ABA problem if you have:
        #
        #   a1 = A()
        #   w1 = WeakIdRef(a1)
        #   del a1
        #   a2 = A()  # suppose it gets the same ID as a1
        #   w2 = WeakIdRef(a2)
        #   print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        if not isinstance(other, weakref.ref):
            # Like weakref.ref itself, only compare against weak references;
            # calling an arbitrary ``other()`` would raise for non-callables.
            return NotImplemented
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        return self is other

    def __ne__(self, other):
        # weakref.ref defines its own __ne__, which compares the *referents*
        # with ``!=`` (elementwise for Tensors).  Without this override, ``!=``
        # would bypass the identity semantics of __eq__ above, so derive it
        # explicitly.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return NotImplemented
        return not eq
83
84# This is the same as WeakIdRef but equality is checked using hash() rather than id.
85# This will be equivalent to the one above except for classes where hash is not their id.
class _WeakHashRef(weakref.ref):
    """Weak reference that hashes and compares via the referent's ``hash()``.

    Same as ``WeakIdRef`` except that ``hash(key)`` is used instead of
    ``id(key)``.  The hash is computed once at construction.  Two live refs
    compare equal when their referents' hashes are equal; if either referent is
    dead, refs are equal only when they are the same ref object.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which lazily defers hashing the referent until
        # the weakref itself is first hashed, we eagerly compute and cache
        # hash(key) here.  This means key must be hashable at construction
        # time (hash() raises TypeError otherwise).  The slot keeps the name
        # ``_id`` to mirror WeakIdRef even though it stores a hash.
        self._id = hash(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        """Return the referent, or ``None`` if it has been collected."""
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # Use hash equality to determine ref equality.
        # ScriptObject implements __hash__ to return the wrapped IValue's id, so
        # this is equivalent to doing an identity comparison.
        if not isinstance(other, weakref.ref):
            # Like weakref.ref itself, only compare against weak references;
            # calling an arbitrary ``other()`` would raise for non-callables.
            return NotImplemented
        a = self()
        b = other()
        if a is not None and b is not None:
            return hash(a) == hash(b)
        return self is other

    def __ne__(self, other):
        # weakref.ref's inherited __ne__ compares the referents with ``!=``,
        # which disagrees with the hash-based __eq__ above; derive it instead.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return NotImplemented
        return not eq
117
118# This is directly adapted from cpython/Lib/weakref.py
class WeakIdKeyDictionary(MutableMapping):
    """A ``MutableMapping`` whose keys are held weakly and compared by identity.

    Adapted from ``weakref.WeakKeyDictionary`` in CPython's ``Lib/weakref.py``;
    lines marked ``# CHANGED`` diverge from upstream.  Each key is wrapped in
    ``ref_type`` (``WeakIdRef`` by default) and the wrapper is stored as the key
    of the ordinary dict ``self.data``.  Hashing and equality therefore come
    from the wrapper rather than from the key object itself.  Wrappers created
    on insertion carry a callback that removes the entry once the key object is
    collected; while an iteration is in progress, the removal is queued in
    ``_pending_removals`` instead.

    Args:
        dict: optional mapping (or iterable of ``(key, value)`` pairs) used to
            populate the dictionary via :meth:`update`.
        ref_type: weakref class used to wrap keys.  It is called as
            ``ref_type(key)`` for lookups and ``ref_type(key, callback)`` for
            insertions.  Copies produced by :meth:`copy`, ``copy.deepcopy``,
            ``|`` and reflected ``|`` keep the same ``ref_type``.
    """

    def __init__(self, dict=None, ref_type=WeakIdRef):  # CHANGED
        self.data = {}

        self.ref_type = ref_type  # CHANGED

        # Weakref callback attached to every inserted key.  It only holds a
        # weak reference to the dictionary, so it does not keep it alive.
        def remove(k, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    # Mutating self.data now would break the active iterator;
                    # queue the removal for _commit_removals instead.
                    self._pending_removals.append(k)
                else:
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        # Non-empty while an _IterationGuard(self) block is active.
        self._iterating = set()
        # Set when an entry may have been removed explicitly, in which case
        # _pending_removals may hold stale keys that __len__ must scrub.
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        """Delete every key queued in ``_pending_removals`` from ``self.data``.

        Expected to be invoked by ``_IterationGuard`` once the last active
        iteration over this dictionary finishes.
        """
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        pop = self._pending_removals.pop
        d = self.data
        while True:
            try:
                key = pop()
            except IndexError:
                return

            try:
                del d[key]
            except KeyError:
                pass

    def _scrub_removals(self):
        """Forget queued removals whose keys are no longer in ``self.data``."""
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[self.ref_type(key)]  # CHANGED

    def __getitem__(self, key):
        return self.data[self.ref_type(key)]  # CHANGED

    def __len__(self):
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return f"<{self.__class__.__name__} at {id(self):#x}>"

    def __setitem__(self, key, value):
        self.data[self.ref_type(key, self._remove)] = value  # CHANGED

    def copy(self):
        """Return a shallow copy containing the currently-live keys.

        The copy uses the same ``ref_type`` as ``self``.
        """
        new = WeakIdKeyDictionary(ref_type=self.ref_type)  # CHANGED: keep ref_type
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Copy with deep-copied values; keys are shared, not copied."""
        from copy import deepcopy
        new = self.__class__()
        # Propagate ref_type without assuming subclasses accept it in
        # __init__; ``new`` is still empty, so no key has been wrapped yet.
        new.ref_type = self.ref_type
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(self.ref_type(key), default)  # CHANGED

    def __contains__(self, key):
        try:
            wr = self.ref_type(key)  # CHANGED
        except TypeError:
            # Objects that cannot be wrapped (e.g. not weak-referenceable)
            # can never be keys.
            return False
        return wr in self.data

    def items(self):
        """Yield ``(key, value)`` pairs whose keys are still alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        """Yield the keys that are still alive."""
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        """Yield the values whose keys are still alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return list(self.data)

    def popitem(self):
        """Remove and return a live ``(key, value)`` pair.

        Dead entries encountered along the way are discarded.  Raises
        ``KeyError`` (from ``dict.popitem``) once ``self.data`` is empty.
        """
        self._dirty_len = True
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(self.ref_type(key), *args)  # CHANGED

    def setdefault(self, key, default=None):
        return self.data.setdefault(self.ref_type(key, self._remove), default)  # CHANGED

    def update(self, dict=None, **kwargs):
        """Insert entries from a mapping / pair iterable and from ``kwargs``."""
        d = self.data
        if dict is not None:
            if not hasattr(dict, "items"):
                # ``dict`` is shadowed by the parameter, hence type({}).
                dict = type({})(dict)
            for key, value in dict.items():
                d[self.ref_type(key, self._remove)] = value  # CHANGED
        if len(kwargs):
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.copy()
            c.update(other)
            return c
        return NotImplemented

    def __ror__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.__class__()
            # Keep the same key-wrapping semantics as ``self``.
            c.ref_type = self.ref_type
            c.update(other)
            c.update(self)
            return c
        return NotImplemented

    # Default Mapping equality will tests keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        """Compare live keys by ``id()`` and values by ``==`` with a Mapping."""
        if not isinstance(other, Mapping):
            return NotImplemented
        return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
301
# Convenience alias with a Tensor-oriented name: this is the same class
# object as WeakIdKeyDictionary, not a subclass.
WeakTensorKeyDictionary = WeakIdKeyDictionary
304
305
class TensorWeakRef:
    """Weak reference to a Tensor that calls ``_fix_weakref()`` on dereference.

    Dereferencing a raw ``weakref.ref`` to a Tensor skips the
    ``_fix_weakref()`` call; this wrapper always performs it on a live result.
    """

    # Plain weak reference to the wrapped Tensor.
    ref: WeakRef[Tensor]

    def __init__(self, tensor: Tensor):
        assert isinstance(tensor, Tensor)
        self.ref = weakref.ref(tensor)

    def __call__(self):
        """Return the referenced Tensor, or ``None`` once it has been collected."""
        target = self.ref()
        if target is not None:
            assert isinstance(target, Tensor)
            # TODO, add _fix_weakref type binding
            target._fix_weakref()  # type: ignore[attr-defined]
        return target
323