1import asyncio 2import gc 3import inspect 4import re 5import unittest 6from contextlib import contextmanager 7from test import support 8 9support.requires_working_socket(module=True) 10 11from asyncio import run, iscoroutinefunction 12from unittest import IsolatedAsyncioTestCase 13from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock, 14 create_autospec, sentinel, _CallList, seal) 15 16 17def tearDownModule(): 18 asyncio.set_event_loop_policy(None) 19 20 21class AsyncClass: 22 def __init__(self): pass 23 async def async_method(self): pass 24 def normal_method(self): pass 25 26 @classmethod 27 async def async_class_method(cls): pass 28 29 @staticmethod 30 async def async_static_method(): pass 31 32 33class AwaitableClass: 34 def __await__(self): yield 35 36async def async_func(): pass 37 38async def async_func_args(a, b, *, c): pass 39 40def normal_func(): pass 41 42class NormalClass(object): 43 def a(self): pass 44 45 46async_foo_name = f'{__name__}.AsyncClass' 47normal_foo_name = f'{__name__}.NormalClass' 48 49 50@contextmanager 51def assertNeverAwaited(test): 52 with test.assertWarnsRegex(RuntimeWarning, "was never awaited$"): 53 yield 54 # In non-CPython implementations of Python, this is needed because timely 55 # deallocation is not guaranteed by the garbage collector. 56 gc.collect() 57 58 59class AsyncPatchDecoratorTest(unittest.TestCase): 60 def test_is_coroutine_function_patch(self): 61 @patch.object(AsyncClass, 'async_method') 62 def test_async(mock_method): 63 self.assertTrue(iscoroutinefunction(mock_method)) 64 test_async() 65 66 def test_is_async_patch(self): 67 @patch.object(AsyncClass, 'async_method') 68 def test_async(mock_method): 69 m = mock_method() 70 self.assertTrue(inspect.isawaitable(m)) 71 run(m) 72 73 @patch(f'{async_foo_name}.async_method') 74 def test_no_parent_attribute(mock_method): 75 m = mock_method() 76 self.assertTrue(inspect.isawaitable(m)) 77 run(m) 78 79 test_async() 80 test_no_parent_attribute() 81 82 def test_is_AsyncMock_patch(self): 83 @patch.object(AsyncClass, 'async_method') 84 def test_async(mock_method): 85 self.assertIsInstance(mock_method, AsyncMock) 86 87 test_async() 88 89 def test_is_AsyncMock_patch_staticmethod(self): 90 @patch.object(AsyncClass, 'async_static_method') 91 def test_async(mock_method): 92 self.assertIsInstance(mock_method, AsyncMock) 93 94 test_async() 95 96 def test_is_AsyncMock_patch_classmethod(self): 97 @patch.object(AsyncClass, 'async_class_method') 98 def test_async(mock_method): 99 self.assertIsInstance(mock_method, AsyncMock) 100 101 test_async() 102 103 def test_async_def_patch(self): 104 @patch(f"{__name__}.async_func", return_value=1) 105 @patch(f"{__name__}.async_func_args", return_value=2) 106 async def test_async(func_args_mock, func_mock): 107 self.assertEqual(func_args_mock._mock_name, "async_func_args") 108 self.assertEqual(func_mock._mock_name, "async_func") 109 110 self.assertIsInstance(async_func, AsyncMock) 111 self.assertIsInstance(async_func_args, AsyncMock) 112 113 self.assertEqual(await async_func(), 1) 114 self.assertEqual(await async_func_args(1, 2, c=3), 2) 115 116 run(test_async()) 117 self.assertTrue(inspect.iscoroutinefunction(async_func)) 118 119 120class AsyncPatchCMTest(unittest.TestCase): 121 def test_is_async_function_cm(self): 122 def test_async(): 123 with patch.object(AsyncClass, 'async_method') as mock_method: 124 self.assertTrue(iscoroutinefunction(mock_method)) 125 126 test_async() 127 128 def test_is_async_cm(self): 129 def test_async(): 130 with patch.object(AsyncClass, 'async_method') as mock_method: 131 m = mock_method() 132 self.assertTrue(inspect.isawaitable(m)) 133 run(m) 134 135 test_async() 136 137 def test_is_AsyncMock_cm(self): 138 def test_async(): 139 with patch.object(AsyncClass, 'async_method') as mock_method: 140 self.assertIsInstance(mock_method, AsyncMock) 141 142 test_async() 143 144 def test_async_def_cm(self): 145 async def test_async(): 146 with patch(f"{__name__}.async_func", AsyncMock()): 147 self.assertIsInstance(async_func, AsyncMock) 148 self.assertTrue(inspect.iscoroutinefunction(async_func)) 149 150 run(test_async()) 151 152 def test_patch_dict_async_def(self): 153 foo = {'a': 'a'} 154 @patch.dict(foo, {'a': 'b'}) 155 async def test_async(): 156 self.assertEqual(foo['a'], 'b') 157 158 self.assertTrue(iscoroutinefunction(test_async)) 159 run(test_async()) 160 161 def test_patch_dict_async_def_context(self): 162 foo = {'a': 'a'} 163 async def test_async(): 164 with patch.dict(foo, {'a': 'b'}): 165 self.assertEqual(foo['a'], 'b') 166 167 run(test_async()) 168 169 170class AsyncMockTest(unittest.TestCase): 171 def test_iscoroutinefunction_default(self): 172 mock = AsyncMock() 173 self.assertTrue(iscoroutinefunction(mock)) 174 175 def test_iscoroutinefunction_function(self): 176 async def foo(): pass 177 mock = AsyncMock(foo) 178 self.assertTrue(iscoroutinefunction(mock)) 179 self.assertTrue(inspect.iscoroutinefunction(mock)) 180 181 def test_isawaitable(self): 182 mock = AsyncMock() 183 m = mock() 184 self.assertTrue(inspect.isawaitable(m)) 185 run(m) 186 self.assertIn('assert_awaited', dir(mock)) 187 188 def test_iscoroutinefunction_normal_function(self): 189 def foo(): pass 190 mock = AsyncMock(foo) 191 self.assertTrue(iscoroutinefunction(mock)) 192 self.assertTrue(inspect.iscoroutinefunction(mock)) 193 194 def test_future_isfuture(self): 195 loop = asyncio.new_event_loop() 196 fut = loop.create_future() 197 loop.stop() 198 loop.close() 199 mock = AsyncMock(fut) 200 self.assertIsInstance(mock, asyncio.Future) 201 202 203class AsyncAutospecTest(unittest.TestCase): 204 def test_is_AsyncMock_patch(self): 205 @patch(async_foo_name, autospec=True) 206 def test_async(mock_method): 207 self.assertIsInstance(mock_method.async_method, AsyncMock) 208 self.assertIsInstance(mock_method, MagicMock) 209 210 @patch(async_foo_name, autospec=True) 211 def test_normal_method(mock_method): 212 self.assertIsInstance(mock_method.normal_method, MagicMock) 213 214 test_async() 215 test_normal_method() 216 217 def test_create_autospec_instance(self): 218 with self.assertRaises(RuntimeError): 219 create_autospec(async_func, instance=True) 220 221 @unittest.skip('Broken test from https://bugs.python.org/issue37251') 222 def test_create_autospec_awaitable_class(self): 223 self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock) 224 225 def test_create_autospec(self): 226 spec = create_autospec(async_func_args) 227 awaitable = spec(1, 2, c=3) 228 async def main(): 229 await awaitable 230 231 self.assertEqual(spec.await_count, 0) 232 self.assertIsNone(spec.await_args) 233 self.assertEqual(spec.await_args_list, []) 234 spec.assert_not_awaited() 235 236 run(main()) 237 238 self.assertTrue(iscoroutinefunction(spec)) 239 self.assertTrue(asyncio.iscoroutine(awaitable)) 240 self.assertEqual(spec.await_count, 1) 241 self.assertEqual(spec.await_args, call(1, 2, c=3)) 242 self.assertEqual(spec.await_args_list, [call(1, 2, c=3)]) 243 spec.assert_awaited_once() 244 spec.assert_awaited_once_with(1, 2, c=3) 245 spec.assert_awaited_with(1, 2, c=3) 246 spec.assert_awaited() 247 248 with self.assertRaises(AssertionError): 249 spec.assert_any_await(e=1) 250 251 252 def test_patch_with_autospec(self): 253 254 async def test_async(): 255 with patch(f"{__name__}.async_func_args", autospec=True) as mock_method: 256 awaitable = mock_method(1, 2, c=3) 257 self.assertIsInstance(mock_method.mock, AsyncMock) 258 259 self.assertTrue(iscoroutinefunction(mock_method)) 260 self.assertTrue(asyncio.iscoroutine(awaitable)) 261 self.assertTrue(inspect.isawaitable(awaitable)) 262 263 # Verify the default values during mock setup 264 self.assertEqual(mock_method.await_count, 0) 265 self.assertEqual(mock_method.await_args_list, []) 266 self.assertIsNone(mock_method.await_args) 267 mock_method.assert_not_awaited() 268 269 await awaitable 270 271 self.assertEqual(mock_method.await_count, 1) 272 self.assertEqual(mock_method.await_args, call(1, 2, c=3)) 273 self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)]) 274 mock_method.assert_awaited_once() 275 mock_method.assert_awaited_once_with(1, 2, c=3) 276 mock_method.assert_awaited_with(1, 2, c=3) 277 mock_method.assert_awaited() 278 279 mock_method.reset_mock() 280 self.assertEqual(mock_method.await_count, 0) 281 self.assertIsNone(mock_method.await_args) 282 self.assertEqual(mock_method.await_args_list, []) 283 284 run(test_async()) 285 286 287class AsyncSpecTest(unittest.TestCase): 288 def test_spec_normal_methods_on_class(self): 289 def inner_test(mock_type): 290 mock = mock_type(AsyncClass) 291 self.assertIsInstance(mock.async_method, AsyncMock) 292 self.assertIsInstance(mock.normal_method, MagicMock) 293 294 for mock_type in [AsyncMock, MagicMock]: 295 with self.subTest(f"test method types with {mock_type}"): 296 inner_test(mock_type) 297 298 def test_spec_normal_methods_on_class_with_mock(self): 299 mock = Mock(AsyncClass) 300 self.assertIsInstance(mock.async_method, AsyncMock) 301 self.assertIsInstance(mock.normal_method, Mock) 302 303 def test_spec_normal_methods_on_class_with_mock_seal(self): 304 mock = Mock(AsyncClass) 305 seal(mock) 306 with self.assertRaises(AttributeError): 307 mock.normal_method 308 with self.assertRaises(AttributeError): 309 mock.async_method 310 311 def test_spec_mock_type_kw(self): 312 def inner_test(mock_type): 313 async_mock = mock_type(spec=async_func) 314 self.assertIsInstance(async_mock, mock_type) 315 with assertNeverAwaited(self): 316 self.assertTrue(inspect.isawaitable(async_mock())) 317 318 sync_mock = mock_type(spec=normal_func) 319 self.assertIsInstance(sync_mock, mock_type) 320 321 for mock_type in [AsyncMock, MagicMock, Mock]: 322 with self.subTest(f"test spec kwarg with {mock_type}"): 323 inner_test(mock_type) 324 325 def test_spec_mock_type_positional(self): 326 def inner_test(mock_type): 327 async_mock = mock_type(async_func) 328 self.assertIsInstance(async_mock, mock_type) 329 with assertNeverAwaited(self): 330 self.assertTrue(inspect.isawaitable(async_mock())) 331 332 sync_mock = mock_type(normal_func) 333 self.assertIsInstance(sync_mock, mock_type) 334 335 for mock_type in [AsyncMock, MagicMock, Mock]: 336 with self.subTest(f"test spec positional with {mock_type}"): 337 inner_test(mock_type) 338 339 def test_spec_as_normal_kw_AsyncMock(self): 340 mock = AsyncMock(spec=normal_func) 341 self.assertIsInstance(mock, AsyncMock) 342 m = mock() 343 self.assertTrue(inspect.isawaitable(m)) 344 run(m) 345 346 def test_spec_as_normal_positional_AsyncMock(self): 347 mock = AsyncMock(normal_func) 348 self.assertIsInstance(mock, AsyncMock) 349 m = mock() 350 self.assertTrue(inspect.isawaitable(m)) 351 run(m) 352 353 def test_spec_async_mock(self): 354 @patch.object(AsyncClass, 'async_method', spec=True) 355 def test_async(mock_method): 356 self.assertIsInstance(mock_method, AsyncMock) 357 358 test_async() 359 360 def test_spec_parent_not_async_attribute_is(self): 361 @patch(async_foo_name, spec=True) 362 def test_async(mock_method): 363 self.assertIsInstance(mock_method, MagicMock) 364 self.assertIsInstance(mock_method.async_method, AsyncMock) 365 366 test_async() 367 368 def test_target_async_spec_not(self): 369 @patch.object(AsyncClass, 'async_method', spec=NormalClass.a) 370 def test_async_attribute(mock_method): 371 self.assertIsInstance(mock_method, MagicMock) 372 self.assertFalse(inspect.iscoroutine(mock_method)) 373 self.assertFalse(inspect.isawaitable(mock_method)) 374 375 test_async_attribute() 376 377 def test_target_not_async_spec_is(self): 378 @patch.object(NormalClass, 'a', spec=async_func) 379 def test_attribute_not_async_spec_is(mock_async_func): 380 self.assertIsInstance(mock_async_func, AsyncMock) 381 test_attribute_not_async_spec_is() 382 383 def test_spec_async_attributes(self): 384 @patch(normal_foo_name, spec=AsyncClass) 385 def test_async_attributes_coroutines(MockNormalClass): 386 self.assertIsInstance(MockNormalClass.async_method, AsyncMock) 387 self.assertIsInstance(MockNormalClass, MagicMock) 388 389 test_async_attributes_coroutines() 390 391 392class AsyncSpecSetTest(unittest.TestCase): 393 def test_is_AsyncMock_patch(self): 394 @patch.object(AsyncClass, 'async_method', spec_set=True) 395 def test_async(async_method): 396 self.assertIsInstance(async_method, AsyncMock) 397 test_async() 398 399 def test_is_async_AsyncMock(self): 400 mock = AsyncMock(spec_set=AsyncClass.async_method) 401 self.assertTrue(iscoroutinefunction(mock)) 402 self.assertIsInstance(mock, AsyncMock) 403 404 def test_is_child_AsyncMock(self): 405 mock = MagicMock(spec_set=AsyncClass) 406 self.assertTrue(iscoroutinefunction(mock.async_method)) 407 self.assertFalse(iscoroutinefunction(mock.normal_method)) 408 self.assertIsInstance(mock.async_method, AsyncMock) 409 self.assertIsInstance(mock.normal_method, MagicMock) 410 self.assertIsInstance(mock, MagicMock) 411 412 def test_magicmock_lambda_spec(self): 413 mock_obj = MagicMock() 414 mock_obj.mock_func = MagicMock(spec=lambda x: x) 415 416 with patch.object(mock_obj, "mock_func") as cm: 417 self.assertIsInstance(cm, MagicMock) 418 419 420class AsyncArguments(IsolatedAsyncioTestCase): 421 async def test_add_return_value(self): 422 async def addition(self, var): pass 423 424 mock = AsyncMock(addition, return_value=10) 425 output = await mock(5) 426 427 self.assertEqual(output, 10) 428 429 async def test_add_side_effect_exception(self): 430 async def addition(var): pass 431 mock = AsyncMock(addition, side_effect=Exception('err')) 432 with self.assertRaises(Exception): 433 await mock(5) 434 435 async def test_add_side_effect_coroutine(self): 436 async def addition(var): 437 return var + 1 438 mock = AsyncMock(side_effect=addition) 439 result = await mock(5) 440 self.assertEqual(result, 6) 441 442 async def test_add_side_effect_normal_function(self): 443 def addition(var): 444 return var + 1 445 mock = AsyncMock(side_effect=addition) 446 result = await mock(5) 447 self.assertEqual(result, 6) 448 449 async def test_add_side_effect_iterable(self): 450 vals = [1, 2, 3] 451 mock = AsyncMock(side_effect=vals) 452 for item in vals: 453 self.assertEqual(await mock(), item) 454 455 with self.assertRaises(StopAsyncIteration) as e: 456 await mock() 457 458 async def test_add_side_effect_exception_iterable(self): 459 class SampleException(Exception): 460 pass 461 462 vals = [1, SampleException("foo")] 463 mock = AsyncMock(side_effect=vals) 464 self.assertEqual(await mock(), 1) 465 466 with self.assertRaises(SampleException) as e: 467 await mock() 468 469 async def test_return_value_AsyncMock(self): 470 value = AsyncMock(return_value=10) 471 mock = AsyncMock(return_value=value) 472 result = await mock() 473 self.assertIs(result, value) 474 475 async def test_return_value_awaitable(self): 476 fut = asyncio.Future() 477 fut.set_result(None) 478 mock = AsyncMock(return_value=fut) 479 result = await mock() 480 self.assertIsInstance(result, asyncio.Future) 481 482 async def test_side_effect_awaitable_values(self): 483 fut = asyncio.Future() 484 fut.set_result(None) 485 486 mock = AsyncMock(side_effect=[fut]) 487 result = await mock() 488 self.assertIsInstance(result, asyncio.Future) 489 490 with self.assertRaises(StopAsyncIteration): 491 await mock() 492 493 async def test_side_effect_is_AsyncMock(self): 494 effect = AsyncMock(return_value=10) 495 mock = AsyncMock(side_effect=effect) 496 497 result = await mock() 498 self.assertEqual(result, 10) 499 500 async def test_wraps_coroutine(self): 501 value = asyncio.Future() 502 503 ran = False 504 async def inner(): 505 nonlocal ran 506 ran = True 507 return value 508 509 mock = AsyncMock(wraps=inner) 510 result = await mock() 511 self.assertEqual(result, value) 512 mock.assert_awaited() 513 self.assertTrue(ran) 514 515 async def test_wraps_normal_function(self): 516 value = 1 517 518 ran = False 519 def inner(): 520 nonlocal ran 521 ran = True 522 return value 523 524 mock = AsyncMock(wraps=inner) 525 result = await mock() 526 self.assertEqual(result, value) 527 mock.assert_awaited() 528 self.assertTrue(ran) 529 530 async def test_await_args_list_order(self): 531 async_mock = AsyncMock() 532 mock2 = async_mock(2) 533 mock1 = async_mock(1) 534 await mock1 535 await mock2 536 async_mock.assert_has_awaits([call(1), call(2)]) 537 self.assertEqual(async_mock.await_args_list, [call(1), call(2)]) 538 self.assertEqual(async_mock.call_args_list, [call(2), call(1)]) 539 540 541class AsyncMagicMethods(unittest.TestCase): 542 def test_async_magic_methods_return_async_mocks(self): 543 m_mock = MagicMock() 544 self.assertIsInstance(m_mock.__aenter__, AsyncMock) 545 self.assertIsInstance(m_mock.__aexit__, AsyncMock) 546 self.assertIsInstance(m_mock.__anext__, AsyncMock) 547 # __aiter__ is actually a synchronous object 548 # so should return a MagicMock 549 self.assertIsInstance(m_mock.__aiter__, MagicMock) 550 551 def test_sync_magic_methods_return_magic_mocks(self): 552 a_mock = AsyncMock() 553 self.assertIsInstance(a_mock.__enter__, MagicMock) 554 self.assertIsInstance(a_mock.__exit__, MagicMock) 555 self.assertIsInstance(a_mock.__next__, MagicMock) 556 self.assertIsInstance(a_mock.__len__, MagicMock) 557 558 def test_magicmock_has_async_magic_methods(self): 559 m_mock = MagicMock() 560 self.assertTrue(hasattr(m_mock, "__aenter__")) 561 self.assertTrue(hasattr(m_mock, "__aexit__")) 562 self.assertTrue(hasattr(m_mock, "__anext__")) 563 564 def test_asyncmock_has_sync_magic_methods(self): 565 a_mock = AsyncMock() 566 self.assertTrue(hasattr(a_mock, "__enter__")) 567 self.assertTrue(hasattr(a_mock, "__exit__")) 568 self.assertTrue(hasattr(a_mock, "__next__")) 569 self.assertTrue(hasattr(a_mock, "__len__")) 570 571 def test_magic_methods_are_async_functions(self): 572 m_mock = MagicMock() 573 self.assertIsInstance(m_mock.__aenter__, AsyncMock) 574 self.assertIsInstance(m_mock.__aexit__, AsyncMock) 575 # AsyncMocks are also coroutine functions 576 self.assertTrue(iscoroutinefunction(m_mock.__aenter__)) 577 self.assertTrue(iscoroutinefunction(m_mock.__aexit__)) 578 579class AsyncContextManagerTest(unittest.TestCase): 580 581 class WithAsyncContextManager: 582 async def __aenter__(self, *args, **kwargs): pass 583 584 async def __aexit__(self, *args, **kwargs): pass 585 586 class WithSyncContextManager: 587 def __enter__(self, *args, **kwargs): pass 588 589 def __exit__(self, *args, **kwargs): pass 590 591 class ProductionCode: 592 # Example real-world(ish) code 593 def __init__(self): 594 self.session = None 595 596 async def main(self): 597 async with self.session.post('https://python.org') as response: 598 val = await response.json() 599 return val 600 601 def test_set_return_value_of_aenter(self): 602 def inner_test(mock_type): 603 pc = self.ProductionCode() 604 pc.session = MagicMock(name='sessionmock') 605 cm = mock_type(name='magic_cm') 606 response = AsyncMock(name='response') 607 response.json = AsyncMock(return_value={'json': 123}) 608 cm.__aenter__.return_value = response 609 pc.session.post.return_value = cm 610 result = run(pc.main()) 611 self.assertEqual(result, {'json': 123}) 612 613 for mock_type in [AsyncMock, MagicMock]: 614 with self.subTest(f"test set return value of aenter with {mock_type}"): 615 inner_test(mock_type) 616 617 def test_mock_supports_async_context_manager(self): 618 def inner_test(mock_type): 619 called = False 620 cm = self.WithAsyncContextManager() 621 cm_mock = mock_type(cm) 622 623 async def use_context_manager(): 624 nonlocal called 625 async with cm_mock as result: 626 called = True 627 return result 628 629 cm_result = run(use_context_manager()) 630 self.assertTrue(called) 631 self.assertTrue(cm_mock.__aenter__.called) 632 self.assertTrue(cm_mock.__aexit__.called) 633 cm_mock.__aenter__.assert_awaited() 634 cm_mock.__aexit__.assert_awaited() 635 # We mock __aenter__ so it does not return self 636 self.assertIsNot(cm_mock, cm_result) 637 638 for mock_type in [AsyncMock, MagicMock]: 639 with self.subTest(f"test context manager magics with {mock_type}"): 640 inner_test(mock_type) 641 642 643 def test_mock_customize_async_context_manager(self): 644 instance = self.WithAsyncContextManager() 645 mock_instance = MagicMock(instance) 646 647 expected_result = object() 648 mock_instance.__aenter__.return_value = expected_result 649 650 async def use_context_manager(): 651 async with mock_instance as result: 652 return result 653 654 self.assertIs(run(use_context_manager()), expected_result) 655 656 def test_mock_customize_async_context_manager_with_coroutine(self): 657 enter_called = False 658 exit_called = False 659 660 async def enter_coroutine(*args): 661 nonlocal enter_called 662 enter_called = True 663 664 async def exit_coroutine(*args): 665 nonlocal exit_called 666 exit_called = True 667 668 instance = self.WithAsyncContextManager() 669 mock_instance = MagicMock(instance) 670 671 mock_instance.__aenter__ = enter_coroutine 672 mock_instance.__aexit__ = exit_coroutine 673 674 async def use_context_manager(): 675 async with mock_instance: 676 pass 677 678 run(use_context_manager()) 679 self.assertTrue(enter_called) 680 self.assertTrue(exit_called) 681 682 def test_context_manager_raise_exception_by_default(self): 683 async def raise_in(context_manager): 684 async with context_manager: 685 raise TypeError() 686 687 instance = self.WithAsyncContextManager() 688 mock_instance = MagicMock(instance) 689 with self.assertRaises(TypeError): 690 run(raise_in(mock_instance)) 691 692 693class AsyncIteratorTest(unittest.TestCase): 694 class WithAsyncIterator(object): 695 def __init__(self): 696 self.items = ["foo", "NormalFoo", "baz"] 697 698 def __aiter__(self): pass 699 700 async def __anext__(self): pass 701 702 def test_aiter_set_return_value(self): 703 mock_iter = AsyncMock(name="tester") 704 mock_iter.__aiter__.return_value = [1, 2, 3] 705 async def main(): 706 return [i async for i in mock_iter] 707 result = run(main()) 708 self.assertEqual(result, [1, 2, 3]) 709 710 def test_mock_aiter_and_anext_asyncmock(self): 711 def inner_test(mock_type): 712 instance = self.WithAsyncIterator() 713 mock_instance = mock_type(instance) 714 # Check that the mock and the real thing bahave the same 715 # __aiter__ is not actually async, so not a coroutinefunction 716 self.assertFalse(iscoroutinefunction(instance.__aiter__)) 717 self.assertFalse(iscoroutinefunction(mock_instance.__aiter__)) 718 # __anext__ is async 719 self.assertTrue(iscoroutinefunction(instance.__anext__)) 720 self.assertTrue(iscoroutinefunction(mock_instance.__anext__)) 721 722 for mock_type in [AsyncMock, MagicMock]: 723 with self.subTest(f"test aiter and anext corourtine with {mock_type}"): 724 inner_test(mock_type) 725 726 727 def test_mock_async_for(self): 728 async def iterate(iterator): 729 accumulator = [] 730 async for item in iterator: 731 accumulator.append(item) 732 733 return accumulator 734 735 expected = ["FOO", "BAR", "BAZ"] 736 def test_default(mock_type): 737 mock_instance = mock_type(self.WithAsyncIterator()) 738 self.assertEqual(run(iterate(mock_instance)), []) 739 740 741 def test_set_return_value(mock_type): 742 mock_instance = mock_type(self.WithAsyncIterator()) 743 mock_instance.__aiter__.return_value = expected[:] 744 self.assertEqual(run(iterate(mock_instance)), expected) 745 746 def test_set_return_value_iter(mock_type): 747 mock_instance = mock_type(self.WithAsyncIterator()) 748 mock_instance.__aiter__.return_value = iter(expected[:]) 749 self.assertEqual(run(iterate(mock_instance)), expected) 750 751 for mock_type in [AsyncMock, MagicMock]: 752 with self.subTest(f"default value with {mock_type}"): 753 test_default(mock_type) 754 755 with self.subTest(f"set return_value with {mock_type}"): 756 test_set_return_value(mock_type) 757 758 with self.subTest(f"set return_value iterator with {mock_type}"): 759 test_set_return_value_iter(mock_type) 760 761 762class AsyncMockAssert(unittest.TestCase): 763 def setUp(self): 764 self.mock = AsyncMock() 765 766 async def _runnable_test(self, *args, **kwargs): 767 await self.mock(*args, **kwargs) 768 769 async def _await_coroutine(self, coroutine): 770 return await coroutine 771 772 def test_assert_called_but_not_awaited(self): 773 mock = AsyncMock(AsyncClass) 774 with assertNeverAwaited(self): 775 mock.async_method() 776 self.assertTrue(iscoroutinefunction(mock.async_method)) 777 mock.async_method.assert_called() 778 mock.async_method.assert_called_once() 779 mock.async_method.assert_called_once_with() 780 with self.assertRaises(AssertionError): 781 mock.assert_awaited() 782 with self.assertRaises(AssertionError): 783 mock.async_method.assert_awaited() 784 785 def test_assert_called_then_awaited(self): 786 mock = AsyncMock(AsyncClass) 787 mock_coroutine = mock.async_method() 788 mock.async_method.assert_called() 789 mock.async_method.assert_called_once() 790 mock.async_method.assert_called_once_with() 791 with self.assertRaises(AssertionError): 792 mock.async_method.assert_awaited() 793 794 run(self._await_coroutine(mock_coroutine)) 795 # Assert we haven't re-called the function 796 mock.async_method.assert_called_once() 797 mock.async_method.assert_awaited() 798 mock.async_method.assert_awaited_once() 799 mock.async_method.assert_awaited_once_with() 800 801 def test_assert_called_and_awaited_at_same_time(self): 802 with self.assertRaises(AssertionError): 803 self.mock.assert_awaited() 804 805 with self.assertRaises(AssertionError): 806 self.mock.assert_called() 807 808 run(self._runnable_test()) 809 self.mock.assert_called_once() 810 self.mock.assert_awaited_once() 811 812 def test_assert_called_twice_and_awaited_once(self): 813 mock = AsyncMock(AsyncClass) 814 coroutine = mock.async_method() 815 # The first call will be awaited so no warning there 816 # But this call will never get awaited, so it will warn here 817 with assertNeverAwaited(self): 818 mock.async_method() 819 with self.assertRaises(AssertionError): 820 mock.async_method.assert_awaited() 821 mock.async_method.assert_called() 822 run(self._await_coroutine(coroutine)) 823 mock.async_method.assert_awaited() 824 mock.async_method.assert_awaited_once() 825 826 def test_assert_called_once_and_awaited_twice(self): 827 mock = AsyncMock(AsyncClass) 828 coroutine = mock.async_method() 829 mock.async_method.assert_called_once() 830 run(self._await_coroutine(coroutine)) 831 with self.assertRaises(RuntimeError): 832 # Cannot reuse already awaited coroutine 833 run(self._await_coroutine(coroutine)) 834 mock.async_method.assert_awaited() 835 836 def test_assert_awaited_but_not_called(self): 837 with self.assertRaises(AssertionError): 838 self.mock.assert_awaited() 839 with self.assertRaises(AssertionError): 840 self.mock.assert_called() 841 with self.assertRaises(TypeError): 842 # You cannot await an AsyncMock, it must be a coroutine 843 run(self._await_coroutine(self.mock)) 844 845 with self.assertRaises(AssertionError): 846 self.mock.assert_awaited() 847 with self.assertRaises(AssertionError): 848 self.mock.assert_called() 849 850 def test_assert_has_calls_not_awaits(self): 851 kalls = [call('foo')] 852 with assertNeverAwaited(self): 853 self.mock('foo') 854 self.mock.assert_has_calls(kalls) 855 with self.assertRaises(AssertionError): 856 self.mock.assert_has_awaits(kalls) 857 858 def test_assert_has_mock_calls_on_async_mock_no_spec(self): 859 with assertNeverAwaited(self): 860 self.mock() 861 kalls_empty = [('', (), {})] 862 self.assertEqual(self.mock.mock_calls, kalls_empty) 863 864 with assertNeverAwaited(self): 865 self.mock('foo') 866 with assertNeverAwaited(self): 867 self.mock('baz') 868 mock_kalls = ([call(), call('foo'), call('baz')]) 869 self.assertEqual(self.mock.mock_calls, mock_kalls) 870 871 def test_assert_has_mock_calls_on_async_mock_with_spec(self): 872 a_class_mock = AsyncMock(AsyncClass) 873 with assertNeverAwaited(self): 874 a_class_mock.async_method() 875 kalls_empty = [('', (), {})] 876 self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty) 877 self.assertEqual(a_class_mock.mock_calls, [call.async_method()]) 878 879 with assertNeverAwaited(self): 880 a_class_mock.async_method(1, 2, 3, a=4, b=5) 881 method_kalls = [call(), call(1, 2, 3, a=4, b=5)] 882 mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)] 883 self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls) 884 self.assertEqual(a_class_mock.mock_calls, mock_kalls) 885 886 def test_async_method_calls_recorded(self): 887 with assertNeverAwaited(self): 888 self.mock.something(3, fish=None) 889 with assertNeverAwaited(self): 890 self.mock.something_else.something(6, cake=sentinel.Cake) 891 892 self.assertEqual(self.mock.method_calls, [ 893 ("something", (3,), {'fish': None}), 894 ("something_else.something", (6,), {'cake': sentinel.Cake}) 895 ], 896 "method calls not recorded correctly") 897 self.assertEqual(self.mock.something_else.method_calls, 898 [("something", (6,), {'cake': sentinel.Cake})], 899 "method calls not recorded correctly") 900 901 def test_async_arg_lists(self): 902 def assert_attrs(mock): 903 names = ('call_args_list', 'method_calls', 'mock_calls') 904 for name in names: 905 attr = getattr(mock, name) 906 self.assertIsInstance(attr, _CallList) 907 self.assertIsInstance(attr, list) 908 self.assertEqual(attr, []) 909 910 assert_attrs(self.mock) 911 with assertNeverAwaited(self): 912 self.mock() 913 with assertNeverAwaited(self): 914 self.mock(1, 2) 915 with assertNeverAwaited(self): 916 self.mock(a=3) 917 918 self.mock.reset_mock() 919 assert_attrs(self.mock) 920 921 a_mock = AsyncMock(AsyncClass) 922 with assertNeverAwaited(self): 923 a_mock.async_method() 924 with assertNeverAwaited(self): 925 a_mock.async_method(1, a=3) 926 927 a_mock.reset_mock() 928 assert_attrs(a_mock) 929 930 def test_assert_awaited(self): 931 with self.assertRaises(AssertionError): 932 self.mock.assert_awaited() 933 934 run(self._runnable_test()) 935 self.mock.assert_awaited() 936 937 def test_assert_awaited_once(self): 938 with self.assertRaises(AssertionError): 939 self.mock.assert_awaited_once() 940 941 run(self._runnable_test()) 942 self.mock.assert_awaited_once() 943 944 run(self._runnable_test()) 945 with self.assertRaises(AssertionError): 946 self.mock.assert_awaited_once() 947 948 def test_assert_awaited_with(self): 949 msg = 'Not awaited' 950 with self.assertRaisesRegex(AssertionError, msg): 951 self.mock.assert_awaited_with('foo') 952 953 run(self._runnable_test()) 954 msg = 'expected await not found' 955 with self.assertRaisesRegex(AssertionError, msg): 956 self.mock.assert_awaited_with('foo') 957 958 run(self._runnable_test('foo')) 959 self.mock.assert_awaited_with('foo') 960 961 run(self._runnable_test('SomethingElse')) 962 with self.assertRaises(AssertionError): 963 self.mock.assert_awaited_with('foo') 964 965 def test_assert_awaited_once_with(self): 966 with self.assertRaises(AssertionError): 967 self.mock.assert_awaited_once_with('foo') 968 969 run(self._runnable_test('foo')) 970 self.mock.assert_awaited_once_with('foo') 971 972 run(self._runnable_test('foo')) 973 with self.assertRaises(AssertionError): 974 self.mock.assert_awaited_once_with('foo') 975 976 def test_assert_any_wait(self): 977 with self.assertRaises(AssertionError): 978 self.mock.assert_any_await('foo') 979 980 run(self._runnable_test('baz')) 981 with self.assertRaises(AssertionError): 982 self.mock.assert_any_await('foo') 983 984 run(self._runnable_test('foo')) 985 self.mock.assert_any_await('foo') 986 987 run(self._runnable_test('SomethingElse')) 988 self.mock.assert_any_await('foo') 989 990 def test_assert_has_awaits_no_order(self): 991 calls = [call('foo'), call('baz')] 992 993 with self.assertRaises(AssertionError) as cm: 994 self.mock.assert_has_awaits(calls) 995 self.assertEqual(len(cm.exception.args), 1) 996 997 run(self._runnable_test('foo')) 998 with self.assertRaises(AssertionError): 999 self.mock.assert_has_awaits(calls) 1000 1001 run(self._runnable_test('foo')) 1002 with self.assertRaises(AssertionError): 1003 self.mock.assert_has_awaits(calls) 1004 1005 run(self._runnable_test('baz')) 1006 self.mock.assert_has_awaits(calls) 1007 1008 run(self._runnable_test('SomethingElse')) 1009 self.mock.assert_has_awaits(calls) 1010 1011 def test_awaits_asserts_with_any(self): 1012 class Foo: 1013 def __eq__(self, other): pass 1014 1015 run(self._runnable_test(Foo(), 1)) 1016 1017 self.mock.assert_has_awaits([call(ANY, 1)]) 1018 self.mock.assert_awaited_with(ANY, 1) 1019 self.mock.assert_any_await(ANY, 1) 1020 1021 def test_awaits_asserts_with_spec_and_any(self): 1022 class Foo: 1023 def __eq__(self, other): pass 1024 1025 mock_with_spec = AsyncMock(spec=Foo) 1026 1027 async def _custom_mock_runnable_test(*args): 1028 await mock_with_spec(*args) 1029 1030 run(_custom_mock_runnable_test(Foo(), 1)) 1031 mock_with_spec.assert_has_awaits([call(ANY, 1)]) 1032 mock_with_spec.assert_awaited_with(ANY, 1) 1033 mock_with_spec.assert_any_await(ANY, 1) 1034 1035 def test_assert_has_awaits_ordered(self): 1036 calls = [call('foo'), call('baz')] 1037 with self.assertRaises(AssertionError): 1038 self.mock.assert_has_awaits(calls, any_order=True) 1039 1040 run(self._runnable_test('baz')) 1041 with self.assertRaises(AssertionError): 1042 self.mock.assert_has_awaits(calls, any_order=True) 1043 1044 run(self._runnable_test('bamf')) 1045 with self.assertRaises(AssertionError): 1046 self.mock.assert_has_awaits(calls, any_order=True) 1047 1048 run(self._runnable_test('foo')) 1049 self.mock.assert_has_awaits(calls, any_order=True) 1050 1051 run(self._runnable_test('qux')) 1052 self.mock.assert_has_awaits(calls, any_order=True) 1053 1054 def test_assert_not_awaited(self): 1055 self.mock.assert_not_awaited() 1056 1057 run(self._runnable_test()) 1058 with self.assertRaises(AssertionError): 1059 self.mock.assert_not_awaited() 1060 1061 def test_assert_has_awaits_not_matching_spec_error(self): 1062 async def f(x=None): pass 1063 1064 self.mock = AsyncMock(spec=f) 1065 run(self._runnable_test(1)) 1066 1067 with self.assertRaisesRegex( 1068 AssertionError, 1069 '^{}$'.format( 1070 re.escape('Awaits not found.\n' 1071 'Expected: [call()]\n' 1072 'Actual: [call(1)]'))) as cm: 1073 self.mock.assert_has_awaits([call()]) 1074 self.assertIsNone(cm.exception.__cause__) 1075 1076 with self.assertRaisesRegex( 1077 AssertionError, 1078 '^{}$'.format( 1079 re.escape( 1080 'Error processing expected awaits.\n' 1081 "Errors: [None, TypeError('too many positional " 1082 "arguments')]\n" 1083 'Expected: [call(), call(1, 2)]\n' 1084 'Actual: [call(1)]'))) as cm: 1085 self.mock.assert_has_awaits([call(), call(1, 2)]) 1086 self.assertIsInstance(cm.exception.__cause__, TypeError) 1087 1088 1089if __name__ == '__main__': 1090 unittest.main() 1091