xref: /aosp_15_r20/external/pytorch/test/test_weak.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: meta tensors"]
2
3import copy
4import gc
5import random
6import threading
7import unittest
8
9import torch
10from torch.testing._internal.common_utils import (
11    find_library_location,
12    IS_FBCODE,
13    IS_MACOS,
14    IS_SANDCASTLE,
15    IS_WINDOWS,
16    run_tests,
17    TestCase,
18)
19from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary
20
21
def C():
    """Return a fresh one-element random tensor.

    The tests below use it as a weakly-referenceable dictionary key.
    """
    single = (1,)
    return torch.randn(single)
24
25
26# These tests are ported from cpython/Lib/test/test_weakref.py,
27# but adapted to use tensor rather than object
28class WeakTest(TestCase):
29    COUNT = 10
30
31    def test_make_weak_keyed_dict_from_dict(self):
32        o = torch.randn(2)
33        dict = WeakIdKeyDictionary({o: 364})
34        self.assertEqual(dict[o], 364)
35
36    def test_make_weak_keyed_dict_from_weak_keyed_dict(self):
37        o = torch.randn(3)
38        dict = WeakIdKeyDictionary({o: 364})
39        dict2 = WeakIdKeyDictionary(dict)
40        self.assertEqual(dict[o], 364)
41
42    def check_popitem(self, klass, key1, value1, key2, value2):
43        weakdict = klass()
44        weakdict[key1] = value1
45        weakdict[key2] = value2
46        self.assertEqual(len(weakdict), 2)
47        k, v = weakdict.popitem()
48        self.assertEqual(len(weakdict), 1)
49        if k is key1:
50            self.assertIs(v, value1)
51        else:
52            self.assertIs(v, value2)
53        k, v = weakdict.popitem()
54        self.assertEqual(len(weakdict), 0)
55        if k is key1:
56            self.assertIs(v, value1)
57        else:
58            self.assertIs(v, value2)
59
60    def test_weak_keyed_dict_popitem(self):
61        self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2")
62
63    def check_setdefault(self, klass, key, value1, value2):
64        self.assertIsNot(
65            value1,
66            value2,
67            "invalid test -- value parameters must be distinct objects",
68        )
69        weakdict = klass()
70        o = weakdict.setdefault(key, value1)
71        self.assertIs(o, value1)
72        self.assertIn(key, weakdict)
73        self.assertIs(weakdict.get(key), value1)
74        self.assertIs(weakdict[key], value1)
75
76        o = weakdict.setdefault(key, value2)
77        self.assertIs(o, value1)
78        self.assertIn(key, weakdict)
79        self.assertIs(weakdict.get(key), value1)
80        self.assertIs(weakdict[key], value1)
81
82    def test_weak_keyed_dict_setdefault(self):
83        self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2")
84
85    def check_update(self, klass, dict):
86        #
87        #  This exercises d.update(), len(d), d.keys(), k in d,
88        #  d.get(), d[].
89        #
90        weakdict = klass()
91        weakdict.update(dict)
92        self.assertEqual(len(weakdict), len(dict))
93        for k in weakdict.keys():
94            self.assertIn(k, dict, "mysterious new key appeared in weak dict")
95            v = dict.get(k)
96            self.assertIs(v, weakdict[k])
97            self.assertIs(v, weakdict.get(k))
98        for k in dict.keys():
99            self.assertIn(k, weakdict, "original key disappeared in weak dict")
100            v = dict[k]
101            self.assertIs(v, weakdict[k])
102            self.assertIs(v, weakdict.get(k))
103
104    def test_weak_keyed_dict_update(self):
105        self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3})
106
107    def test_weak_keyed_delitem(self):
108        d = WeakIdKeyDictionary()
109        o1 = torch.randn(1)
110        o2 = torch.randn(2)
111        d[o1] = "something"
112        d[o2] = "something"
113        self.assertEqual(len(d), 2)
114        del d[o1]
115        self.assertEqual(len(d), 1)
116        self.assertEqual(list(d.keys()), [o2])
117
118    def test_weak_keyed_union_operators(self):
119        try:
120            {} | {}
121        except TypeError:
122            self.skipTest("dict union not supported in this Python")
123
124        o1 = C()
125        o2 = C()
126        o3 = C()
127        wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2})
128        wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4})
129        wkd3 = wkd1.copy()
130        d1 = {o2: "5", o3: "6"}
131        pairs = [(o2, 7), (o3, 8)]
132
133        tmp1 = wkd1 | wkd2  # Between two WeakKeyDictionaries
134        self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2))
135        self.assertIs(type(tmp1), WeakIdKeyDictionary)
136        wkd1 |= wkd2
137        self.assertEqual(wkd1, tmp1)
138
139        tmp2 = wkd2 | d1  # Between WeakKeyDictionary and mapping
140        self.assertEqual(dict(tmp2), dict(wkd2) | d1)
141        self.assertIs(type(tmp2), WeakIdKeyDictionary)
142        wkd2 |= d1
143        self.assertEqual(wkd2, tmp2)
144
145        tmp3 = wkd3.copy()  # Between WeakKeyDictionary and iterable key, value
146        tmp3 |= pairs
147        self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs))
148        self.assertIs(type(tmp3), WeakIdKeyDictionary)
149
150        tmp4 = d1 | wkd3  # Testing .__ror__
151        self.assertEqual(dict(tmp4), d1 | dict(wkd3))
152        self.assertIs(type(tmp4), WeakIdKeyDictionary)
153
154        del o1
155        self.assertNotIn(4, tmp1.values())
156        self.assertNotIn(4, tmp2.values())
157        self.assertNotIn(1, tmp3.values())
158        self.assertNotIn(1, tmp4.values())
159
160    def test_weak_keyed_bad_delitem(self):
161        d = WeakIdKeyDictionary()
162        o = torch.randn(1)
163        # An attempt to delete an object that isn't there should raise
164        # KeyError.  It didn't before 2.3.
165        self.assertRaises(KeyError, d.__delitem__, o)
166        self.assertRaises(KeyError, d.__getitem__, o)
167
168        # If a key isn't of a weakly referencable type, __getitem__ and
169        # __setitem__ raise TypeError.  __delitem__ should too.
170        self.assertRaises(TypeError, d.__delitem__, 13)
171        self.assertRaises(TypeError, d.__getitem__, 13)
172        self.assertRaises(TypeError, d.__setitem__, 13, 13)
173
174    def test_make_weak_keyed_dict_repr(self):
175        dict = WeakIdKeyDictionary()
176        self.assertRegex(repr(dict), "<WeakIdKeyDictionary at 0x.*>")
177
178    def check_threaded_weak_dict_copy(self, type_, deepcopy):
179        # `deepcopy` should be either True or False.
180        exc = []
181
182        # Cannot give these slots as weakrefs weren't supported
183        # on these objects until later versions of Python
184        class DummyKey:  # noqa: B903
185            def __init__(self, ctr):
186                self.ctr = ctr
187
188        class DummyValue:  # noqa: B903
189            def __init__(self, ctr):
190                self.ctr = ctr
191
192        def dict_copy(d, exc):
193            try:
194                if deepcopy is True:
195                    _ = copy.deepcopy(d)
196                else:
197                    _ = d.copy()
198            except Exception as ex:
199                exc.append(ex)
200
201        def pop_and_collect(lst):
202            gc_ctr = 0
203            while lst:
204                i = random.randint(0, len(lst) - 1)
205                gc_ctr += 1
206                lst.pop(i)
207                if gc_ctr % 10000 == 0:
208                    gc.collect()  # just in case
209
210        d = type_()
211        keys = []
212        values = []
213        # Initialize d with many entries
214        for i in range(70000):
215            k, v = DummyKey(i), DummyValue(i)
216            keys.append(k)
217            values.append(v)
218            d[k] = v
219            del k
220            del v
221
222        t_copy = threading.Thread(target=dict_copy, args=(d, exc))
223        t_collect = threading.Thread(target=pop_and_collect, args=(keys,))
224
225        t_copy.start()
226        t_collect.start()
227
228        t_copy.join()
229        t_collect.join()
230
231        # Test exceptions
232        if exc:
233            raise exc[0]
234
235    def test_threaded_weak_key_dict_copy(self):
236        # Issue #35615: Weakref keys or values getting GC'ed during dict
237        # copying should not result in a crash.
238        self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False)
239
240    def test_threaded_weak_key_dict_deepcopy(self):
241        # Issue #35615: Weakref keys or values getting GC'ed during dict
242        # copying should not result in a crash.
243        self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True)
244
245
# Adapted from cpython/Lib/test/mapping_tests.py
class WeakKeyDictionaryTestCase(TestCase):
    """Generic mapping-protocol tests run against ``WeakIdKeyDictionary``.

    Fixtures set up in ``__init__``:

    * ``self.reference``: plain dict of the (key, value) pairs that tests
      load into the mapping under test.
    * ``self.inmapping``: one-item dict whose pair is also in ``reference``.
    * ``self.other``: one-item dict whose pair is NOT in ``reference``.

    Keys come from ``_reference``. The mapping under test comes from
    ``_empty_mapping`` / ``_full_mapping``.
    """

    # Class-level keys. These tensors stay strongly referenced for the
    # lifetime of the class, so the weak mappings cannot lose them mid-test.
    __ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3}
    type2test = WeakIdKeyDictionary

    def _reference(self):
        """Return a fresh plain-dict copy of the reference key/value pairs."""
        return self.__ref.copy()

    def _empty_mapping(self):
        """Return an empty mapping object"""
        return self.type2test()

    def _full_mapping(self, data):
        """Return a mapping object with the value contained in data
        dictionary"""
        x = self._empty_mapping()
        for key, value in data.items():
            x[key] = value
        return x

    def __init__(self, *args, **kw):
        # NOTE(review): this calls unittest.TestCase.__init__ directly, which
        # bypasses torch's TestCase.__init__. Kept as in the original port.
        unittest.TestCase.__init__(self, *args, **kw)
        # _reference() already returns a fresh dict, so no extra copy is needed.
        self.reference = self._reference()

        # A (key, value) pair not in the mapping
        key, value = self.reference.popitem()
        self.other = {key: value}

        # A (key, value) pair in the mapping
        key, value = self.reference.popitem()
        self.inmapping = {key: value}
        self.reference[key] = value

    def test_read(self):
        # Test for read only operations on mapping
        p = self._empty_mapping()
        p1 = dict(p)  # workaround for singleton objects
        d = self._full_mapping(self.reference)
        if d is p:
            p = p1
        # Indexing
        for key, value in self.reference.items():
            self.assertEqual(d[key], value)
        knownkey = next(iter(self.other.keys()))
        self.assertRaises(KeyError, lambda: d[knownkey])
        # len
        self.assertEqual(len(p), 0)
        self.assertEqual(len(d), len(self.reference))
        # __contains__
        for k in self.reference:
            self.assertIn(k, d)
        for k in self.other:
            self.assertNotIn(k, d)
        # cmp
        # NB: don't use assertEqual, that doesn't actually use ==
        self.assertTrue(p == p)
        self.assertTrue(d == d)
        self.assertTrue(p != d)
        self.assertTrue(d != p)
        # bool
        if p:
            self.fail("Empty mapping must compare to False")
        if not d:
            self.fail("Full mapping must compare to True")

        # keys(), items(), iterkeys() ...
        def check_iterandlist(it, lst, ref):
            self.assertTrue(hasattr(it, "__next__"))
            self.assertTrue(hasattr(it, "__iter__"))
            x = list(it)
            self.assertTrue(set(x) == set(lst) == set(ref))

        check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d.values()), list(d.values()), self.reference.values())
        check_iterandlist(iter(d.items()), list(d.items()), self.reference.items())
        # get
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.get(key, knownvalue), value)
        self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
        self.assertNotIn(knownkey, d)

    def test_write(self):
        # Test for write operations on mapping
        p = self._empty_mapping()
        # Indexing
        for key, value in self.reference.items():
            p[key] = value
            self.assertEqual(p[key], value)
        for key in self.reference.keys():
            del p[key]
            self.assertRaises(KeyError, lambda: p[key])
        p = self._empty_mapping()
        # update
        p.update(self.reference)
        self.assertEqual(dict(p), self.reference)
        items = list(p.items())
        p = self._empty_mapping()
        p.update(items)
        self.assertEqual(dict(p), self.reference)
        d = self._full_mapping(self.reference)
        # setdefault
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.setdefault(key, knownvalue), value)
        self.assertEqual(d[key], value)
        self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
        self.assertEqual(d[knownkey], knownvalue)
        # pop
        self.assertEqual(d.pop(knownkey), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertRaises(KeyError, d.pop, knownkey)
        default = 909
        d[knownkey] = knownvalue
        self.assertEqual(d.pop(knownkey, default), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertEqual(d.pop(knownkey, default), default)
        # popitem
        key, value = d.popitem()
        self.assertNotIn(key, d)
        self.assertEqual(value, self.reference[key])
        p = self._empty_mapping()
        self.assertRaises(KeyError, p.popitem)

    def test_constructor(self):
        self.assertEqual(self._empty_mapping(), self._empty_mapping())

    def test_bool(self):
        # Check truthiness of the mapping under test, not of the plain
        # reference dict.
        empty = self._empty_mapping()
        full = self._full_mapping(self.reference)
        self.assertTrue(not empty)
        self.assertTrue(full)
        self.assertTrue(bool(empty) is False)
        self.assertTrue(bool(full) is True)

    def test_keys(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        d = self._full_mapping(self.reference)
        keys = list(d.keys())
        present = next(iter(self.inmapping.keys()))
        absent = next(iter(self.other.keys()))
        # Compare by identity. A plain ``in`` on a list falls back to ``==``,
        # which is elementwise (and possibly ambiguous) for tensor keys.
        self.assertTrue(any(k is present for k in keys))
        self.assertFalse(any(k is absent for k in keys))
        self.assertRaises(TypeError, d.keys, None)

    def test_values(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.values()), [])

        self.assertRaises(TypeError, d.values, None)

    def test_items(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.items()), [])

        self.assertRaises(TypeError, d.items, None)

    def test_len(self):
        d = self._empty_mapping()
        self.assertEqual(len(d), 0)

    def test_getitem(self):
        d = self._full_mapping(self.reference)
        key, value = next(iter(self.inmapping.items()))
        self.assertEqual(d[key], value)

        self.assertRaises(TypeError, d.__getitem__)

    def test_update(self):
        # mapping argument
        d = self._empty_mapping()
        d.update(self.other)
        self.assertEqual(list(d.items()), list(self.other.items()))

        # No argument
        d = self._empty_mapping()
        d.update()
        self.assertEqual(d, self._empty_mapping())

        # item sequence
        d = self._empty_mapping()
        d.update(list(self.other.items()))
        self.assertEqual(list(d.items()), list(self.other.items()))

        # Iterator (a one-shot iterator, not a re-iterable view)
        d = self._empty_mapping()
        d.update(iter(self.other.items()))
        self.assertEqual(list(d.items()), list(self.other.items()))

        # FIXME: Doesn't work with UserDict
        # self.assertRaises((TypeError, AttributeError), d.update, None)
        self.assertRaises((TypeError, AttributeError), d.update, 42)

        outerself = self

        class SimpleUserDict:
            def __init__(self) -> None:
                self.d = outerself.reference

            def keys(self):
                return self.d.keys()

            def __getitem__(self, i):
                return self.d[i]

        d.clear()
        d.update(SimpleUserDict())
        i1 = sorted((id(k), v) for k, v in d.items())
        i2 = sorted((id(k), v) for k, v in self.reference.items())
        self.assertEqual(i1, i2)

        class Exc(Exception):
            pass

        d = self._empty_mapping()

        class FailingKeys:
            def keys(self):
                raise Exc

        self.assertRaises(Exc, d.update, FailingKeys())

        d.clear()

        class FailingKeysIteration:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = 1

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i:
                            self.i = 0
                            return "a"
                        raise Exc

                return BogonIter()

            def __getitem__(self, key):
                return key

        self.assertRaises(Exc, d.update, FailingKeysIteration())

        class FailingGetItem:
            def keys(self):
                class BogonIter:
                    def __init__(self) -> None:
                        self.i = ord("a")

                    def __iter__(self):
                        return self

                    def __next__(self):
                        if self.i <= ord("z"):
                            rtn = chr(self.i)
                            self.i += 1
                            return rtn
                        raise StopIteration

                return BogonIter()

            def __getitem__(self, key):
                raise Exc

        self.assertRaises(Exc, d.update, FailingGetItem())

        d = self._empty_mapping()

        class badseq:
            def __iter__(self):
                return self

            def __next__(self):
                raise Exc

        self.assertRaises(Exc, d.update, badseq())

        self.assertRaises(ValueError, d.update, [(1, 2, 3)])

    # no test_fromkeys or test_copy as both os.environ and selves don't support it

    def test_get(self):
        absent = next(iter(self.other.keys()))
        key, value = next(iter(self.inmapping.items()))
        d = self._empty_mapping()
        self.assertTrue(d.get(absent) is None)
        self.assertEqual(d.get(absent, 3), 3)
        # Check lookups on the populated mapping under test.
        d = self._full_mapping(self.reference)
        self.assertTrue(d.get(absent) is None)
        self.assertEqual(d.get(absent, 3), 3)
        self.assertEqual(d.get(key), value)
        self.assertEqual(d.get(key, 3), value)
        self.assertRaises(TypeError, d.get)
        self.assertRaises(TypeError, d.get, None, None, None)

    def test_setdefault(self):
        d = self._empty_mapping()
        self.assertRaises(TypeError, d.setdefault)

    def test_popitem(self):
        d = self._empty_mapping()
        self.assertRaises(KeyError, d.popitem)
        self.assertRaises(TypeError, d.popitem, 42)

    def test_pop(self):
        d = self._empty_mapping()
        k, v = next(iter(self.inmapping.items()))
        d[k] = v
        self.assertRaises(KeyError, d.pop, next(iter(self.other.keys())))

        self.assertEqual(d.pop(k), v)
        self.assertEqual(len(d), 0)

        self.assertRaises(KeyError, d.pop, k)
566
567
# Adapted from cpython/Lib/test/mapping_tests.py
class WeakKeyDictionaryScriptObjectTestCase(WeakKeyDictionaryTestCase):
    """Mapping-protocol tests for ``WeakIdKeyDictionary`` with TorchScript keys.

    Every ``test_*`` method is inherited from ``WeakKeyDictionaryTestCase``.
    Only the fixtures differ: keys are ``torch.classes._TorchScriptTesting._Foo``
    objects, and the mapping under test is
    ``WeakIdKeyDictionary(ref_type=_WeakHashRef)``.
    """

    def _reference(self):
        """Return a fresh dict mapping three ``_Foo`` script objects to 1, 2, 3."""
        return {
            torch.classes._TorchScriptTesting._Foo(1, 2): 1,
            torch.classes._TorchScriptTesting._Foo(2, 3): 2,
            torch.classes._TorchScriptTesting._Foo(3, 4): 3,
        }

    def _empty_mapping(self):
        """Return an empty mapping object"""
        return WeakIdKeyDictionary(ref_type=_WeakHashRef)

    def setUp(self):
        # On macOS __init__ neither loads the custom-class library nor builds
        # the fixtures, so skip before anything touches them.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # Run the base TestCase per-test setup as well.
        super().setUp()

    def __init__(self, *args, **kw):
        # The custom-class library must be loaded before _reference() can
        # construct _Foo keys, which happens in the base-class __init__.
        if IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_MACOS:
            # don't load the library, just skip the tests in setUp
            unittest.TestCase.__init__(self, *args, **kw)
            return
        else:
            lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
            torch.ops.load_library(str(find_library_location(lib_name)))

        # Base __init__ builds self.reference / self.other / self.inmapping
        # from this class's _reference().
        super().__init__(*args, **kw)
907
908
if __name__ == "__main__":
    # Script entry point: run_tests comes from
    # torch.testing._internal.common_utils.
    run_tests()
911