1# Owner(s): ["module: meta tensors"] 2 3import copy 4import gc 5import random 6import threading 7import unittest 8 9import torch 10from torch.testing._internal.common_utils import ( 11 find_library_location, 12 IS_FBCODE, 13 IS_MACOS, 14 IS_SANDCASTLE, 15 IS_WINDOWS, 16 run_tests, 17 TestCase, 18) 19from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary 20 21 22def C(): 23 return torch.randn(1) 24 25 26# These tests are ported from cpython/Lib/test/test_weakref.py, 27# but adapted to use tensor rather than object 28class WeakTest(TestCase): 29 COUNT = 10 30 31 def test_make_weak_keyed_dict_from_dict(self): 32 o = torch.randn(2) 33 dict = WeakIdKeyDictionary({o: 364}) 34 self.assertEqual(dict[o], 364) 35 36 def test_make_weak_keyed_dict_from_weak_keyed_dict(self): 37 o = torch.randn(3) 38 dict = WeakIdKeyDictionary({o: 364}) 39 dict2 = WeakIdKeyDictionary(dict) 40 self.assertEqual(dict[o], 364) 41 42 def check_popitem(self, klass, key1, value1, key2, value2): 43 weakdict = klass() 44 weakdict[key1] = value1 45 weakdict[key2] = value2 46 self.assertEqual(len(weakdict), 2) 47 k, v = weakdict.popitem() 48 self.assertEqual(len(weakdict), 1) 49 if k is key1: 50 self.assertIs(v, value1) 51 else: 52 self.assertIs(v, value2) 53 k, v = weakdict.popitem() 54 self.assertEqual(len(weakdict), 0) 55 if k is key1: 56 self.assertIs(v, value1) 57 else: 58 self.assertIs(v, value2) 59 60 def test_weak_keyed_dict_popitem(self): 61 self.check_popitem(WeakIdKeyDictionary, C(), "value 1", C(), "value 2") 62 63 def check_setdefault(self, klass, key, value1, value2): 64 self.assertIsNot( 65 value1, 66 value2, 67 "invalid test -- value parameters must be distinct objects", 68 ) 69 weakdict = klass() 70 o = weakdict.setdefault(key, value1) 71 self.assertIs(o, value1) 72 self.assertIn(key, weakdict) 73 self.assertIs(weakdict.get(key), value1) 74 self.assertIs(weakdict[key], value1) 75 76 o = weakdict.setdefault(key, value2) 77 self.assertIs(o, value1) 78 self.assertIn(key, weakdict) 79 self.assertIs(weakdict.get(key), value1) 80 self.assertIs(weakdict[key], value1) 81 82 def test_weak_keyed_dict_setdefault(self): 83 self.check_setdefault(WeakIdKeyDictionary, C(), "value 1", "value 2") 84 85 def check_update(self, klass, dict): 86 # 87 # This exercises d.update(), len(d), d.keys(), k in d, 88 # d.get(), d[]. 89 # 90 weakdict = klass() 91 weakdict.update(dict) 92 self.assertEqual(len(weakdict), len(dict)) 93 for k in weakdict.keys(): 94 self.assertIn(k, dict, "mysterious new key appeared in weak dict") 95 v = dict.get(k) 96 self.assertIs(v, weakdict[k]) 97 self.assertIs(v, weakdict.get(k)) 98 for k in dict.keys(): 99 self.assertIn(k, weakdict, "original key disappeared in weak dict") 100 v = dict[k] 101 self.assertIs(v, weakdict[k]) 102 self.assertIs(v, weakdict.get(k)) 103 104 def test_weak_keyed_dict_update(self): 105 self.check_update(WeakIdKeyDictionary, {C(): 1, C(): 2, C(): 3}) 106 107 def test_weak_keyed_delitem(self): 108 d = WeakIdKeyDictionary() 109 o1 = torch.randn(1) 110 o2 = torch.randn(2) 111 d[o1] = "something" 112 d[o2] = "something" 113 self.assertEqual(len(d), 2) 114 del d[o1] 115 self.assertEqual(len(d), 1) 116 self.assertEqual(list(d.keys()), [o2]) 117 118 def test_weak_keyed_union_operators(self): 119 try: 120 {} | {} 121 except TypeError: 122 self.skipTest("dict union not supported in this Python") 123 124 o1 = C() 125 o2 = C() 126 o3 = C() 127 wkd1 = WeakIdKeyDictionary({o1: 1, o2: 2}) 128 wkd2 = WeakIdKeyDictionary({o3: 3, o1: 4}) 129 wkd3 = wkd1.copy() 130 d1 = {o2: "5", o3: "6"} 131 pairs = [(o2, 7), (o3, 8)] 132 133 tmp1 = wkd1 | wkd2 # Between two WeakKeyDictionaries 134 self.assertEqual(dict(tmp1), dict(wkd1) | dict(wkd2)) 135 self.assertIs(type(tmp1), WeakIdKeyDictionary) 136 wkd1 |= wkd2 137 self.assertEqual(wkd1, tmp1) 138 139 tmp2 = wkd2 | d1 # Between WeakKeyDictionary and mapping 140 self.assertEqual(dict(tmp2), dict(wkd2) | d1) 141 self.assertIs(type(tmp2), WeakIdKeyDictionary) 142 wkd2 |= d1 143 self.assertEqual(wkd2, tmp2) 144 145 tmp3 = wkd3.copy() # Between WeakKeyDictionary and iterable key, value 146 tmp3 |= pairs 147 self.assertEqual(dict(tmp3), dict(wkd3) | dict(pairs)) 148 self.assertIs(type(tmp3), WeakIdKeyDictionary) 149 150 tmp4 = d1 | wkd3 # Testing .__ror__ 151 self.assertEqual(dict(tmp4), d1 | dict(wkd3)) 152 self.assertIs(type(tmp4), WeakIdKeyDictionary) 153 154 del o1 155 self.assertNotIn(4, tmp1.values()) 156 self.assertNotIn(4, tmp2.values()) 157 self.assertNotIn(1, tmp3.values()) 158 self.assertNotIn(1, tmp4.values()) 159 160 def test_weak_keyed_bad_delitem(self): 161 d = WeakIdKeyDictionary() 162 o = torch.randn(1) 163 # An attempt to delete an object that isn't there should raise 164 # KeyError. It didn't before 2.3. 165 self.assertRaises(KeyError, d.__delitem__, o) 166 self.assertRaises(KeyError, d.__getitem__, o) 167 168 # If a key isn't of a weakly referencable type, __getitem__ and 169 # __setitem__ raise TypeError. __delitem__ should too. 170 self.assertRaises(TypeError, d.__delitem__, 13) 171 self.assertRaises(TypeError, d.__getitem__, 13) 172 self.assertRaises(TypeError, d.__setitem__, 13, 13) 173 174 def test_make_weak_keyed_dict_repr(self): 175 dict = WeakIdKeyDictionary() 176 self.assertRegex(repr(dict), "<WeakIdKeyDictionary at 0x.*>") 177 178 def check_threaded_weak_dict_copy(self, type_, deepcopy): 179 # `deepcopy` should be either True or False. 180 exc = [] 181 182 # Cannot give these slots as weakrefs weren't supported 183 # on these objects until later versions of Python 184 class DummyKey: # noqa: B903 185 def __init__(self, ctr): 186 self.ctr = ctr 187 188 class DummyValue: # noqa: B903 189 def __init__(self, ctr): 190 self.ctr = ctr 191 192 def dict_copy(d, exc): 193 try: 194 if deepcopy is True: 195 _ = copy.deepcopy(d) 196 else: 197 _ = d.copy() 198 except Exception as ex: 199 exc.append(ex) 200 201 def pop_and_collect(lst): 202 gc_ctr = 0 203 while lst: 204 i = random.randint(0, len(lst) - 1) 205 gc_ctr += 1 206 lst.pop(i) 207 if gc_ctr % 10000 == 0: 208 gc.collect() # just in case 209 210 d = type_() 211 keys = [] 212 values = [] 213 # Initialize d with many entries 214 for i in range(70000): 215 k, v = DummyKey(i), DummyValue(i) 216 keys.append(k) 217 values.append(v) 218 d[k] = v 219 del k 220 del v 221 222 t_copy = threading.Thread(target=dict_copy, args=(d, exc)) 223 t_collect = threading.Thread(target=pop_and_collect, args=(keys,)) 224 225 t_copy.start() 226 t_collect.start() 227 228 t_copy.join() 229 t_collect.join() 230 231 # Test exceptions 232 if exc: 233 raise exc[0] 234 235 def test_threaded_weak_key_dict_copy(self): 236 # Issue #35615: Weakref keys or values getting GC'ed during dict 237 # copying should not result in a crash. 238 self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, False) 239 240 def test_threaded_weak_key_dict_deepcopy(self): 241 # Issue #35615: Weakref keys or values getting GC'ed during dict 242 # copying should not result in a crash. 243 self.check_threaded_weak_dict_copy(WeakIdKeyDictionary, True) 244 245 246# Adapted from cpython/Lib/test/mapping_tests.py 247class WeakKeyDictionaryTestCase(TestCase): 248 __ref = {torch.randn(1): 1, torch.randn(2): 2, torch.randn(3): 3} 249 type2test = WeakIdKeyDictionary 250 251 def _reference(self): 252 return self.__ref.copy() 253 254 def _empty_mapping(self): 255 """Return an empty mapping object""" 256 return self.type2test() 257 258 def _full_mapping(self, data): 259 """Return a mapping object with the value contained in data 260 dictionary""" 261 x = self._empty_mapping() 262 for key, value in data.items(): 263 x[key] = value 264 return x 265 266 def __init__(self, *args, **kw): 267 unittest.TestCase.__init__(self, *args, **kw) 268 self.reference = self._reference().copy() 269 270 # A (key, value) pair not in the mapping 271 key, value = self.reference.popitem() 272 self.other = {key: value} 273 274 # A (key, value) pair in the mapping 275 key, value = self.reference.popitem() 276 self.inmapping = {key: value} 277 self.reference[key] = value 278 279 def test_read(self): 280 # Test for read only operations on mapping 281 p = self._empty_mapping() 282 p1 = dict(p) # workaround for singleton objects 283 d = self._full_mapping(self.reference) 284 if d is p: 285 p = p1 286 # Indexing 287 for key, value in self.reference.items(): 288 self.assertEqual(d[key], value) 289 knownkey = next(iter(self.other.keys())) 290 self.assertRaises(KeyError, lambda: d[knownkey]) 291 # len 292 self.assertEqual(len(p), 0) 293 self.assertEqual(len(d), len(self.reference)) 294 # __contains__ 295 for k in self.reference: 296 self.assertIn(k, d) 297 for k in self.other: 298 self.assertNotIn(k, d) 299 # cmp 300 self.assertTrue( 301 p == p 302 ) # NB: don't use assertEqual, that doesn't actually use == 303 self.assertTrue(d == d) 304 self.assertTrue(p != d) 305 self.assertTrue(d != p) 306 # bool 307 if p: 308 self.fail("Empty mapping must compare to False") 309 if not d: 310 self.fail("Full mapping must compare to True") 311 312 # keys(), items(), iterkeys() ... 313 def check_iterandlist(iter, lst, ref): 314 self.assertTrue(hasattr(iter, "__next__")) 315 self.assertTrue(hasattr(iter, "__iter__")) 316 x = list(iter) 317 self.assertTrue(set(x) == set(lst) == set(ref)) 318 319 check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys()) 320 check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) 321 check_iterandlist(iter(d.values()), list(d.values()), self.reference.values()) 322 check_iterandlist(iter(d.items()), list(d.items()), self.reference.items()) 323 # get 324 key, value = next(iter(d.items())) 325 knownkey, knownvalue = next(iter(self.other.items())) 326 self.assertEqual(d.get(key, knownvalue), value) 327 self.assertEqual(d.get(knownkey, knownvalue), knownvalue) 328 self.assertNotIn(knownkey, d) 329 330 def test_write(self): 331 # Test for write operations on mapping 332 p = self._empty_mapping() 333 # Indexing 334 for key, value in self.reference.items(): 335 p[key] = value 336 self.assertEqual(p[key], value) 337 for key in self.reference.keys(): 338 del p[key] 339 self.assertRaises(KeyError, lambda: p[key]) 340 p = self._empty_mapping() 341 # update 342 p.update(self.reference) 343 self.assertEqual(dict(p), self.reference) 344 items = list(p.items()) 345 p = self._empty_mapping() 346 p.update(items) 347 self.assertEqual(dict(p), self.reference) 348 d = self._full_mapping(self.reference) 349 # setdefault 350 key, value = next(iter(d.items())) 351 knownkey, knownvalue = next(iter(self.other.items())) 352 self.assertEqual(d.setdefault(key, knownvalue), value) 353 self.assertEqual(d[key], value) 354 self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) 355 self.assertEqual(d[knownkey], knownvalue) 356 # pop 357 self.assertEqual(d.pop(knownkey), knownvalue) 358 self.assertNotIn(knownkey, d) 359 self.assertRaises(KeyError, d.pop, knownkey) 360 default = 909 361 d[knownkey] = knownvalue 362 self.assertEqual(d.pop(knownkey, default), knownvalue) 363 self.assertNotIn(knownkey, d) 364 self.assertEqual(d.pop(knownkey, default), default) 365 # popitem 366 key, value = d.popitem() 367 self.assertNotIn(key, d) 368 self.assertEqual(value, self.reference[key]) 369 p = self._empty_mapping() 370 self.assertRaises(KeyError, p.popitem) 371 372 def test_constructor(self): 373 self.assertEqual(self._empty_mapping(), self._empty_mapping()) 374 375 def test_bool(self): 376 self.assertTrue(not self._empty_mapping()) 377 self.assertTrue(self.reference) 378 self.assertTrue(bool(self._empty_mapping()) is False) 379 self.assertTrue(bool(self.reference) is True) 380 381 def test_keys(self): 382 d = self._empty_mapping() 383 self.assertEqual(list(d.keys()), []) 384 d = self.reference 385 self.assertIn(next(iter(self.inmapping.keys())), d.keys()) 386 self.assertNotIn(next(iter(self.other.keys())), d.keys()) 387 self.assertRaises(TypeError, d.keys, None) 388 389 def test_values(self): 390 d = self._empty_mapping() 391 self.assertEqual(list(d.values()), []) 392 393 self.assertRaises(TypeError, d.values, None) 394 395 def test_items(self): 396 d = self._empty_mapping() 397 self.assertEqual(list(d.items()), []) 398 399 self.assertRaises(TypeError, d.items, None) 400 401 def test_len(self): 402 d = self._empty_mapping() 403 self.assertEqual(len(d), 0) 404 405 def test_getitem(self): 406 d = self.reference 407 self.assertEqual( 408 d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values())) 409 ) 410 411 self.assertRaises(TypeError, d.__getitem__) 412 413 def test_update(self): 414 # mapping argument 415 d = self._empty_mapping() 416 d.update(self.other) 417 self.assertEqual(list(d.items()), list(self.other.items())) 418 419 # No argument 420 d = self._empty_mapping() 421 d.update() 422 self.assertEqual(d, self._empty_mapping()) 423 424 # item sequence 425 d = self._empty_mapping() 426 d.update(self.other.items()) 427 self.assertEqual(list(d.items()), list(self.other.items())) 428 429 # Iterator 430 d = self._empty_mapping() 431 d.update(self.other.items()) 432 self.assertEqual(list(d.items()), list(self.other.items())) 433 434 # FIXME: Doesn't work with UserDict 435 # self.assertRaises((TypeError, AttributeError), d.update, None) 436 self.assertRaises((TypeError, AttributeError), d.update, 42) 437 438 outerself = self 439 440 class SimpleUserDict: 441 def __init__(self) -> None: 442 self.d = outerself.reference 443 444 def keys(self): 445 return self.d.keys() 446 447 def __getitem__(self, i): 448 return self.d[i] 449 450 d.clear() 451 d.update(SimpleUserDict()) 452 i1 = sorted((id(k), v) for k, v in d.items()) 453 i2 = sorted((id(k), v) for k, v in self.reference.items()) 454 self.assertEqual(i1, i2) 455 456 class Exc(Exception): 457 pass 458 459 d = self._empty_mapping() 460 461 class FailingUserDict: 462 def keys(self): 463 raise Exc 464 465 self.assertRaises(Exc, d.update, FailingUserDict()) 466 467 d.clear() 468 469 class FailingUserDict: 470 def keys(self): 471 class BogonIter: 472 def __init__(self) -> None: 473 self.i = 1 474 475 def __iter__(self): 476 return self 477 478 def __next__(self): 479 if self.i: 480 self.i = 0 481 return "a" 482 raise Exc 483 484 return BogonIter() 485 486 def __getitem__(self, key): 487 return key 488 489 self.assertRaises(Exc, d.update, FailingUserDict()) 490 491 class FailingUserDict: 492 def keys(self): 493 class BogonIter: 494 def __init__(self) -> None: 495 self.i = ord("a") 496 497 def __iter__(self): 498 return self 499 500 def __next__(self): 501 if self.i <= ord("z"): 502 rtn = chr(self.i) 503 self.i += 1 504 return rtn 505 raise StopIteration 506 507 return BogonIter() 508 509 def __getitem__(self, key): 510 raise Exc 511 512 self.assertRaises(Exc, d.update, FailingUserDict()) 513 514 d = self._empty_mapping() 515 516 class badseq: 517 def __iter__(self): 518 return self 519 520 def __next__(self): 521 raise Exc 522 523 self.assertRaises(Exc, d.update, badseq()) 524 525 self.assertRaises(ValueError, d.update, [(1, 2, 3)]) 526 527 # no test_fromkeys or test_copy as both os.environ and selves don't support it 528 529 def test_get(self): 530 d = self._empty_mapping() 531 self.assertTrue(d.get(next(iter(self.other.keys()))) is None) 532 self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) 533 d = self.reference 534 self.assertTrue(d.get(next(iter(self.other.keys()))) is None) 535 self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) 536 self.assertEqual( 537 d.get(next(iter(self.inmapping.keys()))), 538 next(iter(self.inmapping.values())), 539 ) 540 self.assertEqual( 541 d.get(next(iter(self.inmapping.keys())), 3), 542 next(iter(self.inmapping.values())), 543 ) 544 self.assertRaises(TypeError, d.get) 545 self.assertRaises(TypeError, d.get, None, None, None) 546 547 def test_setdefault(self): 548 d = self._empty_mapping() 549 self.assertRaises(TypeError, d.setdefault) 550 551 def test_popitem(self): 552 d = self._empty_mapping() 553 self.assertRaises(KeyError, d.popitem) 554 self.assertRaises(TypeError, d.popitem, 42) 555 556 def test_pop(self): 557 d = self._empty_mapping() 558 k, v = next(iter(self.inmapping.items())) 559 d[k] = v 560 self.assertRaises(KeyError, d.pop, next(iter(self.other.keys()))) 561 562 self.assertEqual(d.pop(k), v) 563 self.assertEqual(len(d), 0) 564 565 self.assertRaises(KeyError, d.pop, k) 566 567 568# Adapted from cpython/Lib/test/mapping_tests.py 569class WeakKeyDictionaryScriptObjectTestCase(TestCase): 570 def _reference(self): 571 self.__ref = { 572 torch.classes._TorchScriptTesting._Foo(1, 2): 1, 573 torch.classes._TorchScriptTesting._Foo(2, 3): 2, 574 torch.classes._TorchScriptTesting._Foo(3, 4): 3, 575 } 576 return self.__ref.copy() 577 578 def _empty_mapping(self): 579 """Return an empty mapping object""" 580 return WeakIdKeyDictionary(ref_type=_WeakHashRef) 581 582 def _full_mapping(self, data): 583 """Return a mapping object with the value contained in data 584 dictionary""" 585 x = self._empty_mapping() 586 for key, value in data.items(): 587 x[key] = value 588 return x 589 590 def setUp(self): 591 if IS_MACOS: 592 raise unittest.SkipTest("non-portable load_library call used in test") 593 594 def __init__(self, *args, **kw): 595 unittest.TestCase.__init__(self, *args, **kw) 596 if IS_SANDCASTLE or IS_FBCODE: 597 torch.ops.load_library( 598 "//caffe2/test/cpp/jit:test_custom_class_registrations" 599 ) 600 elif IS_MACOS: 601 # don't load the library, just skip the tests in setUp 602 return 603 else: 604 lib_file_path = find_library_location("libtorchbind_test.so") 605 if IS_WINDOWS: 606 lib_file_path = find_library_location("torchbind_test.dll") 607 torch.ops.load_library(str(lib_file_path)) 608 609 self.reference = self._reference().copy() 610 611 # A (key, value) pair not in the mapping 612 key, value = self.reference.popitem() 613 self.other = {key: value} 614 615 # A (key, value) pair in the mapping 616 key, value = self.reference.popitem() 617 self.inmapping = {key: value} 618 self.reference[key] = value 619 620 def test_read(self): 621 # Test for read only operations on mapping 622 p = self._empty_mapping() 623 p1 = dict(p) # workaround for singleton objects 624 d = self._full_mapping(self.reference) 625 if d is p: 626 p = p1 627 # Indexing 628 for key, value in self.reference.items(): 629 self.assertEqual(d[key], value) 630 knownkey = next(iter(self.other.keys())) 631 self.assertRaises(KeyError, lambda: d[knownkey]) 632 # len 633 self.assertEqual(len(p), 0) 634 self.assertEqual(len(d), len(self.reference)) 635 # __contains__ 636 for k in self.reference: 637 self.assertIn(k, d) 638 for k in self.other: 639 self.assertNotIn(k, d) 640 # cmp 641 self.assertTrue( 642 p == p 643 ) # NB: don't use assertEqual, that doesn't actually use == 644 self.assertTrue(d == d) 645 self.assertTrue(p != d) 646 self.assertTrue(d != p) 647 # bool 648 if p: 649 self.fail("Empty mapping must compare to False") 650 if not d: 651 self.fail("Full mapping must compare to True") 652 653 # keys(), items(), iterkeys() ... 654 def check_iterandlist(iter, lst, ref): 655 self.assertTrue(hasattr(iter, "__next__")) 656 self.assertTrue(hasattr(iter, "__iter__")) 657 x = list(iter) 658 self.assertTrue(set(x) == set(lst) == set(ref)) 659 660 check_iterandlist(iter(d.keys()), list(d.keys()), self.reference.keys()) 661 check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) 662 check_iterandlist(iter(d.values()), list(d.values()), self.reference.values()) 663 check_iterandlist(iter(d.items()), list(d.items()), self.reference.items()) 664 # get 665 key, value = next(iter(d.items())) 666 knownkey, knownvalue = next(iter(self.other.items())) 667 self.assertEqual(d.get(key, knownvalue), value) 668 self.assertEqual(d.get(knownkey, knownvalue), knownvalue) 669 self.assertNotIn(knownkey, d) 670 671 def test_write(self): 672 # Test for write operations on mapping 673 p = self._empty_mapping() 674 # Indexing 675 for key, value in self.reference.items(): 676 p[key] = value 677 self.assertEqual(p[key], value) 678 for key in self.reference.keys(): 679 del p[key] 680 self.assertRaises(KeyError, lambda: p[key]) 681 p = self._empty_mapping() 682 # update 683 p.update(self.reference) 684 self.assertEqual(dict(p), self.reference) 685 items = list(p.items()) 686 p = self._empty_mapping() 687 p.update(items) 688 self.assertEqual(dict(p), self.reference) 689 d = self._full_mapping(self.reference) 690 # setdefault 691 key, value = next(iter(d.items())) 692 knownkey, knownvalue = next(iter(self.other.items())) 693 self.assertEqual(d.setdefault(key, knownvalue), value) 694 self.assertEqual(d[key], value) 695 self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) 696 self.assertEqual(d[knownkey], knownvalue) 697 # pop 698 self.assertEqual(d.pop(knownkey), knownvalue) 699 self.assertNotIn(knownkey, d) 700 self.assertRaises(KeyError, d.pop, knownkey) 701 default = 909 702 d[knownkey] = knownvalue 703 self.assertEqual(d.pop(knownkey, default), knownvalue) 704 self.assertNotIn(knownkey, d) 705 self.assertEqual(d.pop(knownkey, default), default) 706 # popitem 707 key, value = d.popitem() 708 self.assertNotIn(key, d) 709 self.assertEqual(value, self.reference[key]) 710 p = self._empty_mapping() 711 self.assertRaises(KeyError, p.popitem) 712 713 def test_constructor(self): 714 self.assertEqual(self._empty_mapping(), self._empty_mapping()) 715 716 def test_bool(self): 717 self.assertTrue(not self._empty_mapping()) 718 self.assertTrue(self.reference) 719 self.assertTrue(bool(self._empty_mapping()) is False) 720 self.assertTrue(bool(self.reference) is True) 721 722 def test_keys(self): 723 d = self._empty_mapping() 724 self.assertEqual(list(d.keys()), []) 725 d = self.reference 726 self.assertIn(next(iter(self.inmapping.keys())), d.keys()) 727 self.assertNotIn(next(iter(self.other.keys())), d.keys()) 728 self.assertRaises(TypeError, d.keys, None) 729 730 def test_values(self): 731 d = self._empty_mapping() 732 self.assertEqual(list(d.values()), []) 733 734 self.assertRaises(TypeError, d.values, None) 735 736 def test_items(self): 737 d = self._empty_mapping() 738 self.assertEqual(list(d.items()), []) 739 740 self.assertRaises(TypeError, d.items, None) 741 742 def test_len(self): 743 d = self._empty_mapping() 744 self.assertEqual(len(d), 0) 745 746 def test_getitem(self): 747 d = self.reference 748 self.assertEqual( 749 d[next(iter(self.inmapping.keys()))], next(iter(self.inmapping.values())) 750 ) 751 752 self.assertRaises(TypeError, d.__getitem__) 753 754 def test_update(self): 755 # mapping argument 756 d = self._empty_mapping() 757 d.update(self.other) 758 self.assertEqual(list(d.items()), list(self.other.items())) 759 760 # No argument 761 d = self._empty_mapping() 762 d.update() 763 self.assertEqual(d, self._empty_mapping()) 764 765 # item sequence 766 d = self._empty_mapping() 767 d.update(self.other.items()) 768 self.assertEqual(list(d.items()), list(self.other.items())) 769 770 # Iterator 771 d = self._empty_mapping() 772 d.update(self.other.items()) 773 self.assertEqual(list(d.items()), list(self.other.items())) 774 775 # FIXME: Doesn't work with UserDict 776 # self.assertRaises((TypeError, AttributeError), d.update, None) 777 self.assertRaises((TypeError, AttributeError), d.update, 42) 778 779 outerself = self 780 781 class SimpleUserDict: 782 def __init__(self) -> None: 783 self.d = outerself.reference 784 785 def keys(self): 786 return self.d.keys() 787 788 def __getitem__(self, i): 789 return self.d[i] 790 791 d.clear() 792 d.update(SimpleUserDict()) 793 i1 = sorted((id(k), v) for k, v in d.items()) 794 i2 = sorted((id(k), v) for k, v in self.reference.items()) 795 self.assertEqual(i1, i2) 796 797 class Exc(Exception): 798 pass 799 800 d = self._empty_mapping() 801 802 class FailingUserDict: 803 def keys(self): 804 raise Exc 805 806 self.assertRaises(Exc, d.update, FailingUserDict()) 807 808 d.clear() 809 810 class FailingUserDict: 811 def keys(self): 812 class BogonIter: 813 def __init__(self) -> None: 814 self.i = 1 815 816 def __iter__(self): 817 return self 818 819 def __next__(self): 820 if self.i: 821 self.i = 0 822 return "a" 823 raise Exc 824 825 return BogonIter() 826 827 def __getitem__(self, key): 828 return key 829 830 self.assertRaises(Exc, d.update, FailingUserDict()) 831 832 class FailingUserDict: 833 def keys(self): 834 class BogonIter: 835 def __init__(self) -> None: 836 self.i = ord("a") 837 838 def __iter__(self): 839 return self 840 841 def __next__(self): 842 if self.i <= ord("z"): 843 rtn = chr(self.i) 844 self.i += 1 845 return rtn 846 raise StopIteration 847 848 return BogonIter() 849 850 def __getitem__(self, key): 851 raise Exc 852 853 self.assertRaises(Exc, d.update, FailingUserDict()) 854 855 d = self._empty_mapping() 856 857 class badseq: 858 def __iter__(self): 859 return self 860 861 def __next__(self): 862 raise Exc 863 864 self.assertRaises(Exc, d.update, badseq()) 865 866 self.assertRaises(ValueError, d.update, [(1, 2, 3)]) 867 868 # no test_fromkeys or test_copy as both os.environ and selves don't support it 869 870 def test_get(self): 871 d = self._empty_mapping() 872 self.assertTrue(d.get(next(iter(self.other.keys()))) is None) 873 self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) 874 d = self.reference 875 self.assertTrue(d.get(next(iter(self.other.keys()))) is None) 876 self.assertEqual(d.get(next(iter(self.other.keys())), 3), 3) 877 self.assertEqual( 878 d.get(next(iter(self.inmapping.keys()))), 879 next(iter(self.inmapping.values())), 880 ) 881 self.assertEqual( 882 d.get(next(iter(self.inmapping.keys())), 3), 883 next(iter(self.inmapping.values())), 884 ) 885 self.assertRaises(TypeError, d.get) 886 self.assertRaises(TypeError, d.get, None, None, None) 887 888 def test_setdefault(self): 889 d = self._empty_mapping() 890 self.assertRaises(TypeError, d.setdefault) 891 892 def test_popitem(self): 893 d = self._empty_mapping() 894 self.assertRaises(KeyError, d.popitem) 895 self.assertRaises(TypeError, d.popitem, 42) 896 897 def test_pop(self): 898 d = self._empty_mapping() 899 k, v = next(iter(self.inmapping.items())) 900 d[k] = v 901 self.assertRaises(KeyError, d.pop, next(iter(self.other.keys()))) 902 903 self.assertEqual(d.pop(k), v) 904 self.assertEqual(len(d), 0) 905 906 self.assertRaises(KeyError, d.pop, k) 907 908 909if __name__ == "__main__": 910 run_tests() 911