1import asyncio
2import gc
3import inspect
4import re
5import unittest
6from contextlib import contextmanager
7from test import support
8
# Skip this whole test module up front when the platform has no working
# sockets; module=True applies the skip at import time rather than per test.
# NOTE(review): presumably needed because the asyncio event loops used below
# rely on sockets — confirm against test.support documentation.
support.requires_working_socket(module=True)
10
11from asyncio import run, iscoroutinefunction
12from unittest import IsolatedAsyncioTestCase
13from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
14                           create_autospec, sentinel, _CallList, seal)
15
16
def tearDownModule():
    """Reset asyncio's event loop policy after this module's tests finish.

    Passing ``None`` to ``set_event_loop_policy`` reinstates the default
    policy, so nothing configured by these tests leaks into later modules.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
19
20
class AsyncClass:
    """Spec/patch target mixing coroutine and plain callables.

    It has one synchronous and several asynchronous methods (instance,
    class and static), so tests can check which ones mocks treat as async.
    """

    def __init__(self):
        """Take no arguments and do nothing."""

    async def async_method(self):
        """Coroutine instance method."""

    def normal_method(self):
        """Synchronous instance method."""

    @classmethod
    async def async_class_method(cls):
        """Coroutine class method."""

    @staticmethod
    async def async_static_method():
        """Coroutine static method."""
31
32
class AwaitableClass:
    """Awaitable via ``__await__`` without being a coroutine function."""

    def __await__(self):
        # The bare ``yield`` turns __await__ into a generator function, so it
        # returns the iterator that the await protocol requires.
        yield
35
async def async_func():
    """Argument-less coroutine fixture; tests patch it by dotted name."""
37
async def async_func_args(a, b, *, c):
    """Coroutine fixture: two positional parameters and keyword-only ``c``."""
39
def normal_func():
    """Synchronous, argument-less counterpart of ``async_func``."""
41
class NormalClass:
    """Class with a single synchronous method, used as a non-async spec."""

    def a(self):
        """Synchronous no-op method."""
44
45
# Dotted paths of the two fixture classes in this module, used as patch()
# targets.
async_foo_name = __name__ + '.AsyncClass'
normal_foo_name = __name__ + '.NormalClass'
48
49
@contextmanager
def assertNeverAwaited(test):
    """Require a "never awaited" RuntimeWarning from the enclosed block.

    *test* is the running test case; its ``assertWarnsRegex`` checks that a
    ``RuntimeWarning`` whose message ends in "was never awaited" is emitted
    before the block is left.
    """
    expectation = test.assertWarnsRegex(RuntimeWarning, "was never awaited$")
    with expectation:
        yield
        # Collect while still inside the warning-capture context: outside
        # CPython, objects are not guaranteed to be freed as soon as they
        # become unreachable, so the warning could otherwise arrive too late.
        gc.collect()
57
58
class AsyncPatchDecoratorTest(unittest.TestCase):
    """``patch`` and ``patch.object`` applied as decorators to async targets."""

    def _check_patched_is_async_mock(self, attribute):
        # Shared body: patching any async attribute of AsyncClass must
        # install an AsyncMock.
        @patch.object(AsyncClass, attribute)
        def check(patched):
            self.assertIsInstance(patched, AsyncMock)

        check()

    def test_is_coroutine_function_patch(self):
        # The replacement still reports itself as a coroutine function.
        @patch.object(AsyncClass, 'async_method')
        def check(patched):
            self.assertTrue(iscoroutinefunction(patched))

        check()

    def test_is_async_patch(self):
        # Calling the replacement yields an awaitable; run() awaits it so it
        # is not left un-awaited.
        @patch.object(AsyncClass, 'async_method')
        def via_patch_object(patched):
            awaitable = patched()
            self.assertTrue(inspect.isawaitable(awaitable))
            run(awaitable)

        @patch(f'{async_foo_name}.async_method')
        def via_dotted_path(patched):
            awaitable = patched()
            self.assertTrue(inspect.isawaitable(awaitable))
            run(awaitable)

        for scenario in (via_patch_object, via_dotted_path):
            scenario()

    def test_is_AsyncMock_patch(self):
        self._check_patched_is_async_mock('async_method')

    def test_is_AsyncMock_patch_staticmethod(self):
        self._check_patched_is_async_mock('async_static_method')

    def test_is_AsyncMock_patch_classmethod(self):
        self._check_patched_is_async_mock('async_class_method')

    def test_async_def_patch(self):
        # Stacked decorators apply bottom-up, so the innermost patch
        # (async_func_args) supplies the first mock argument.
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def scenario(args_mock, plain_mock):
            self.assertEqual(args_mock._mock_name, "async_func_args")
            self.assertEqual(plain_mock._mock_name, "async_func")

            # While patched, the module-level names are the mocks.
            for patched in (async_func, async_func_args):
                self.assertIsInstance(patched, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        run(scenario())
        # Once the decorated coroutine finishes, the original is restored.
        self.assertTrue(inspect.iscoroutinefunction(async_func))
118
119
class AsyncPatchCMTest(unittest.TestCase):
    """``patch``, ``patch.object`` and ``patch.dict`` as context managers."""

    def test_is_async_function_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched:
            self.assertTrue(iscoroutinefunction(patched))

    def test_is_async_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched:
            awaitable = patched()
            self.assertTrue(inspect.isawaitable(awaitable))
            # Await it so the coroutine is not left un-awaited.
            run(awaitable)

    def test_is_AsyncMock_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched:
            self.assertIsInstance(patched, AsyncMock)

    def test_async_def_cm(self):
        async def scenario():
            stand_in = AsyncMock()
            with patch(f"{__name__}.async_func", stand_in):
                self.assertIsInstance(async_func, AsyncMock)
            # Leaving the block restores the real coroutine function.
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        run(scenario())

    def test_patch_dict_async_def(self):
        target = {'a': 'a'}

        @patch.dict(target, {'a': 'b'})
        async def scenario():
            self.assertEqual(target['a'], 'b')

        # Decorating with patch.dict keeps the function a coroutine function.
        self.assertTrue(iscoroutinefunction(scenario))
        run(scenario())

    def test_patch_dict_async_def_context(self):
        target = {'a': 'a'}

        async def scenario():
            with patch.dict(target, {'a': 'b'}):
                self.assertEqual(target['a'], 'b')

        run(scenario())
168
169
class AsyncMockTest(unittest.TestCase):
    """Coroutine-function and awaitable behaviour of a bare ``AsyncMock``."""

    def test_iscoroutinefunction_default(self):
        self.assertTrue(iscoroutinefunction(AsyncMock()))

    def test_iscoroutinefunction_function(self):
        async def coroutine_spec(): pass
        mock = AsyncMock(coroutine_spec)
        # The asyncio and inspect predicates must both agree.
        for predicate in (iscoroutinefunction, inspect.iscoroutinefunction):
            self.assertTrue(predicate(mock))

    def test_isawaitable(self):
        mock = AsyncMock()
        pending = mock()
        self.assertTrue(inspect.isawaitable(pending))
        run(pending)
        self.assertIn('assert_awaited', dir(mock))

    def test_iscoroutinefunction_normal_function(self):
        # Even with a synchronous spec, an AsyncMock is a coroutine function.
        def sync_spec(): pass
        mock = AsyncMock(sync_spec)
        for predicate in (iscoroutinefunction, inspect.iscoroutinefunction):
            self.assertTrue(predicate(mock))

    def test_future_isfuture(self):
        loop = asyncio.new_event_loop()
        future = loop.create_future()
        # The loop is only needed to create the future; shut it down first.
        loop.stop()
        loop.close()
        self.assertIsInstance(AsyncMock(future), asyncio.Future)
201
202
class AsyncAutospecTest(unittest.TestCase):
    """``create_autospec`` and ``patch(autospec=True)`` on async targets."""

    def test_is_AsyncMock_patch(self):
        @patch(async_foo_name, autospec=True)
        def check_async_attribute(mock_cls):
            self.assertIsInstance(mock_cls.async_method, AsyncMock)
            self.assertIsInstance(mock_cls, MagicMock)

        @patch(async_foo_name, autospec=True)
        def check_sync_attribute(mock_cls):
            self.assertIsInstance(mock_cls.normal_method, MagicMock)

        check_async_attribute()
        check_sync_attribute()

    def test_create_autospec_instance(self):
        # create_autospec() rejects instance=True for a coroutine function.
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    @unittest.skip('Broken test from https://bugs.python.org/issue37251')
    def test_create_autospec_awaitable_class(self):
        self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        awaitable = spec(1, 2, c=3)

        async def main():
            await awaitable

        # Calling the autospec records a call, but no await yet.
        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        run(main())

        expected = call(1, 2, c=3)
        self.assertTrue(iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, expected)
        self.assertEqual(spec.await_args_list, [expected])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

        with self.assertRaises(AssertionError):
            spec.assert_any_await(e=1)

    def test_patch_with_autospec(self):

        async def scenario():
            target = f"{__name__}.async_func_args"
            with patch(target, autospec=True) as mock_method:
                awaitable = mock_method(1, 2, c=3)
                self.assertIsInstance(mock_method.mock, AsyncMock)

                self.assertTrue(iscoroutinefunction(mock_method))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Freshly set up: nothing has been awaited yet.
                self.assertEqual(mock_method.await_count, 0)
                self.assertEqual(mock_method.await_args_list, [])
                self.assertIsNone(mock_method.await_args)
                mock_method.assert_not_awaited()

                await awaitable

            expected = call(1, 2, c=3)
            self.assertEqual(mock_method.await_count, 1)
            self.assertEqual(mock_method.await_args, expected)
            self.assertEqual(mock_method.await_args_list, [expected])
            mock_method.assert_awaited_once()
            mock_method.assert_awaited_once_with(1, 2, c=3)
            mock_method.assert_awaited_with(1, 2, c=3)
            mock_method.assert_awaited()

            # reset_mock() also clears the await bookkeeping.
            mock_method.reset_mock()
            self.assertEqual(mock_method.await_count, 0)
            self.assertIsNone(mock_method.await_args)
            self.assertEqual(mock_method.await_args_list, [])

        run(scenario())
285
286
class AsyncSpecTest(unittest.TestCase):
    """Whether a spec makes a mock, or its attributes, asynchronous.

    Covers specs given positionally or by keyword, and specs supplied
    through ``patch``/``patch.object``, for AsyncMock, MagicMock and Mock.
    """
    def test_spec_normal_methods_on_class(self):
        # Specced on AsyncClass, async_method becomes an AsyncMock and
        # normal_method a MagicMock, for either parent mock type.
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        # Same split with a plain Mock parent; normal_method is only checked
        # to be a Mock.
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_normal_methods_on_class_with_mock_seal(self):
        # Sealed before any attribute was accessed, the mock refuses to
        # create either spec attribute, async or not.
        mock = Mock(AsyncClass)
        seal(mock)
        with self.assertRaises(AttributeError):
            mock.normal_method
        with self.assertRaises(AttributeError):
            mock.async_method

    def test_spec_mock_type_kw(self):
        # spec= by keyword: the mock keeps the requested type, and calling
        # one specced on a coroutine function returns an awaitable.
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            # The call result is deliberately never bound to a name so that
            # (on CPython) it is released, and warns, inside the context.
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        # Same as test_spec_mock_type_kw, with the spec passed positionally.
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            # Unbound call result: released and warned about in the context.
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        # An AsyncMock stays awaitable even with a synchronous spec (keyword);
        # run() awaits the result.
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        # As above, with the synchronous spec passed positionally.
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_async_mock(self):
        # spec=True on an async method patches it with an AsyncMock.
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        # spec=True on the class: the class mock is a MagicMock while its
        # async attribute is an AsyncMock.
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        # The target is async but the spec is synchronous: the spec wins, so
        # a MagicMock that is not awaitable is installed.
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            # NOTE(review): inspect.iscoroutine tests for a coroutine
            # *object*, which a mock is not, so this check is weak; confirm
            # whether a coroutine-function check was intended.
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        # The target is synchronous but the spec is async: an AsyncMock is
        # installed.
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        # Patching NormalClass with spec=AsyncClass gives a MagicMock whose
        # async attribute is an AsyncMock.
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()
390
391
class AsyncSpecSetTest(unittest.TestCase):
    """``spec_set`` keeps async attributes as ``AsyncMock``."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def check(patched):
            self.assertIsInstance(patched, AsyncMock)

        check()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        parent = MagicMock(spec_set=AsyncClass)
        # Child mocks are cached, so these names refer to the same objects
        # that later attribute lookups would return.
        async_child = parent.async_method
        sync_child = parent.normal_method

        self.assertTrue(iscoroutinefunction(async_child))
        self.assertFalse(iscoroutinefunction(sync_child))
        self.assertIsInstance(async_child, AsyncMock)
        self.assertIsInstance(sync_child, MagicMock)
        self.assertIsInstance(parent, MagicMock)

    def test_magicmock_lambda_spec(self):
        # Patching an attribute whose value is a MagicMock specced on a
        # lambda still installs a MagicMock.
        holder = MagicMock()
        holder.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(holder, "mock_func") as replacement:
            self.assertIsInstance(replacement, MagicMock)
418
419
class AsyncArguments(IsolatedAsyncioTestCase):
    """``return_value``, ``side_effect`` and ``wraps`` on an awaited AsyncMock."""

    async def test_add_return_value(self):
        # Spec whose signature matches the single-argument call below.
        async def addition(var): pass

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        async def addition(var): pass
        mock = AsyncMock(addition, side_effect=Exception('err'))
        # Pin the configured exception, not merely "some Exception".
        with self.assertRaisesRegex(Exception, '^err$'):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        # A coroutine-function side_effect is awaited; its result is returned.
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        # A plain-function side_effect is called; its result is returned.
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # Exhausting the side_effect iterable surfaces StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        # Exception instances inside the iterable are raised, not returned.
        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        # An AsyncMock return_value is handed back as-is.
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        # An awaitable return_value is returned, not awaited.
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        # Awaitables yielded from a side_effect iterable are returned as-is.
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        # An AsyncMock side_effect is awaited and its return_value used.
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        # The wrapped coroutine's result passes through unchanged.
        self.assertIs(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        # await_args_list follows await order; call_args_list follows call
        # order.
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
539
540
class AsyncMagicMethods(unittest.TestCase):
    """Which magic methods MagicMock and AsyncMock expose, and their types."""

    def test_async_magic_methods_return_async_mocks(self):
        m_mock = MagicMock()
        for dunder in ('__aenter__', '__aexit__', '__anext__'):
            self.assertIsInstance(getattr(m_mock, dunder), AsyncMock)
        # __aiter__ itself is synchronous (it returns the async iterator),
        # so it is a MagicMock.
        self.assertIsInstance(m_mock.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        a_mock = AsyncMock()
        for dunder in ('__enter__', '__exit__', '__next__', '__len__'):
            self.assertIsInstance(getattr(a_mock, dunder), MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        m_mock = MagicMock()
        for dunder in ("__aenter__", "__aexit__", "__anext__"):
            self.assertTrue(hasattr(m_mock, dunder))

    def test_asyncmock_has_sync_magic_methods(self):
        a_mock = AsyncMock()
        for dunder in ("__enter__", "__exit__", "__next__", "__len__"):
            self.assertTrue(hasattr(a_mock, dunder))

    def test_magic_methods_are_async_functions(self):
        m_mock = MagicMock()
        dunders = ('__aenter__', '__aexit__')
        for dunder in dunders:
            self.assertIsInstance(getattr(m_mock, dunder), AsyncMock)
        # AsyncMock instances also report themselves as coroutine functions.
        for dunder in dunders:
            self.assertTrue(iscoroutinefunction(getattr(m_mock, dunder)))
578
class AsyncContextManagerTest(unittest.TestCase):
    """``async with`` support on AsyncMock and MagicMock."""

    class WithAsyncContextManager:
        """Spec class implementing the async context-manager protocol."""

        async def __aenter__(self, *args, **kwargs):
            """No-op coroutine entry hook."""

        async def __aexit__(self, *args, **kwargs):
            """No-op coroutine exit hook."""

    class WithSyncContextManager:
        """Spec class implementing the synchronous context-manager protocol."""

        def __enter__(self, *args, **kwargs):
            """No-op entry hook."""

        def __exit__(self, *args, **kwargs):
            """No-op exit hook."""

    class ProductionCode:
        """Example real-world(ish) code posting through an async session."""

        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                return await response.json()

    def test_set_return_value_of_aenter(self):
        def check(mock_type):
            session = MagicMock(name='sessionmock')
            cm = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            # What __aenter__ returns is what ``async with ... as`` binds.
            cm.__aenter__.return_value = response
            session.post.return_value = cm

            code = self.ProductionCode()
            code.session = session
            self.assertEqual(run(code.main()), {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                check(mock_type)

    def test_mock_supports_async_context_manager(self):
        def check(mock_type):
            entered = False
            cm_mock = mock_type(self.WithAsyncContextManager())

            async def use_context_manager():
                nonlocal entered
                async with cm_mock as result:
                    entered = True
                return result

            cm_result = run(use_context_manager())
            self.assertTrue(entered)
            hooks = (cm_mock.__aenter__, cm_mock.__aexit__)
            for hook in hooks:
                self.assertTrue(hook.called)
            for hook in hooks:
                hook.assert_awaited()
            # __aenter__ is mocked, so it does not hand back the mock itself.
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                check(mock_type)

    def test_mock_customize_async_context_manager(self):
        mock_instance = MagicMock(self.WithAsyncContextManager())
        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        seen = []

        async def enter_coroutine(*args):
            seen.append('enter')

        async def exit_coroutine(*args):
            seen.append('exit')

        mock_instance = MagicMock(self.WithAsyncContextManager())
        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        run(use_context_manager())
        self.assertIn('enter', seen)
        self.assertIn('exit', seen)

    def test_context_manager_raise_exception_by_default(self):
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        mock_instance = MagicMock(self.WithAsyncContextManager())
        # The default __aexit__ does not suppress the body's exception.
        with self.assertRaises(TypeError):
            run(raise_in(mock_instance))
691
692
class AsyncIteratorTest(unittest.TestCase):
    """Async iteration (``async for``) over AsyncMock and MagicMock."""

    class WithAsyncIterator:
        # Spec class with a synchronous __aiter__ and an async __anext__,
        # the same split the real async-iterator protocol uses.
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self): pass

        async def __anext__(self): pass

    def test_aiter_set_return_value(self):
        # The values produced by ``async for`` come from
        # __aiter__.return_value.
        mock_iter = AsyncMock(name="tester")
        mock_iter.__aiter__.return_value = [1, 2, 3]
        async def main():
            return [i async for i in mock_iter]
        result = run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same:
            # __aiter__ is not actually async, so not a coroutine function...
            self.assertFalse(iscoroutinefunction(instance.__aiter__))
            self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
            # ...while __anext__ is async.
            self.assertTrue(iscoroutinefunction(instance.__anext__))
            self.assertTrue(iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]

        def test_default(mock_type):
            # With nothing configured, iteration yields no items.
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(run(iterate(mock_instance)), [])

        def test_set_return_value(mock_type):
            # A list return_value is iterated; a copy is used so ``expected``
            # itself is never handed to the mock.
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(run(iterate(mock_instance)), expected)

        def test_set_return_value_iter(mock_type):
            # An iterator return_value works as well as a list.
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)
760
761
762class AsyncMockAssert(unittest.TestCase):
763    def setUp(self):
764        self.mock = AsyncMock()
765
766    async def _runnable_test(self, *args, **kwargs):
767        await self.mock(*args, **kwargs)
768
769    async def _await_coroutine(self, coroutine):
770        return await coroutine
771
772    def test_assert_called_but_not_awaited(self):
773        mock = AsyncMock(AsyncClass)
774        with assertNeverAwaited(self):
775            mock.async_method()
776        self.assertTrue(iscoroutinefunction(mock.async_method))
777        mock.async_method.assert_called()
778        mock.async_method.assert_called_once()
779        mock.async_method.assert_called_once_with()
780        with self.assertRaises(AssertionError):
781            mock.assert_awaited()
782        with self.assertRaises(AssertionError):
783            mock.async_method.assert_awaited()
784
785    def test_assert_called_then_awaited(self):
786        mock = AsyncMock(AsyncClass)
787        mock_coroutine = mock.async_method()
788        mock.async_method.assert_called()
789        mock.async_method.assert_called_once()
790        mock.async_method.assert_called_once_with()
791        with self.assertRaises(AssertionError):
792            mock.async_method.assert_awaited()
793
794        run(self._await_coroutine(mock_coroutine))
795        # Assert we haven't re-called the function
796        mock.async_method.assert_called_once()
797        mock.async_method.assert_awaited()
798        mock.async_method.assert_awaited_once()
799        mock.async_method.assert_awaited_once_with()
800
801    def test_assert_called_and_awaited_at_same_time(self):
802        with self.assertRaises(AssertionError):
803            self.mock.assert_awaited()
804
805        with self.assertRaises(AssertionError):
806            self.mock.assert_called()
807
808        run(self._runnable_test())
809        self.mock.assert_called_once()
810        self.mock.assert_awaited_once()
811
812    def test_assert_called_twice_and_awaited_once(self):
813        mock = AsyncMock(AsyncClass)
814        coroutine = mock.async_method()
815        # The first call will be awaited so no warning there
816        # But this call will never get awaited, so it will warn here
817        with assertNeverAwaited(self):
818            mock.async_method()
819        with self.assertRaises(AssertionError):
820            mock.async_method.assert_awaited()
821        mock.async_method.assert_called()
822        run(self._await_coroutine(coroutine))
823        mock.async_method.assert_awaited()
824        mock.async_method.assert_awaited_once()
825
826    def test_assert_called_once_and_awaited_twice(self):
827        mock = AsyncMock(AsyncClass)
828        coroutine = mock.async_method()
829        mock.async_method.assert_called_once()
830        run(self._await_coroutine(coroutine))
831        with self.assertRaises(RuntimeError):
832            # Cannot reuse already awaited coroutine
833            run(self._await_coroutine(coroutine))
834        mock.async_method.assert_awaited()
835
836    def test_assert_awaited_but_not_called(self):
837        with self.assertRaises(AssertionError):
838            self.mock.assert_awaited()
839        with self.assertRaises(AssertionError):
840            self.mock.assert_called()
841        with self.assertRaises(TypeError):
842            # You cannot await an AsyncMock, it must be a coroutine
843            run(self._await_coroutine(self.mock))
844
845        with self.assertRaises(AssertionError):
846            self.mock.assert_awaited()
847        with self.assertRaises(AssertionError):
848            self.mock.assert_called()
849
850    def test_assert_has_calls_not_awaits(self):
851        kalls = [call('foo')]
852        with assertNeverAwaited(self):
853            self.mock('foo')
854        self.mock.assert_has_calls(kalls)
855        with self.assertRaises(AssertionError):
856            self.mock.assert_has_awaits(kalls)
857
858    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
859        with assertNeverAwaited(self):
860            self.mock()
861        kalls_empty = [('', (), {})]
862        self.assertEqual(self.mock.mock_calls, kalls_empty)
863
864        with assertNeverAwaited(self):
865            self.mock('foo')
866        with assertNeverAwaited(self):
867            self.mock('baz')
868        mock_kalls = ([call(), call('foo'), call('baz')])
869        self.assertEqual(self.mock.mock_calls, mock_kalls)
870
871    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
872        a_class_mock = AsyncMock(AsyncClass)
873        with assertNeverAwaited(self):
874            a_class_mock.async_method()
875        kalls_empty = [('', (), {})]
876        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
877        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
878
879        with assertNeverAwaited(self):
880            a_class_mock.async_method(1, 2, 3, a=4, b=5)
881        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
882        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
883        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
884        self.assertEqual(a_class_mock.mock_calls, mock_kalls)
885
886    def test_async_method_calls_recorded(self):
887        with assertNeverAwaited(self):
888            self.mock.something(3, fish=None)
889        with assertNeverAwaited(self):
890            self.mock.something_else.something(6, cake=sentinel.Cake)
891
892        self.assertEqual(self.mock.method_calls, [
893            ("something", (3,), {'fish': None}),
894            ("something_else.something", (6,), {'cake': sentinel.Cake})
895        ],
896            "method calls not recorded correctly")
897        self.assertEqual(self.mock.something_else.method_calls,
898                         [("something", (6,), {'cake': sentinel.Cake})],
899                         "method calls not recorded correctly")
900
901    def test_async_arg_lists(self):
902        def assert_attrs(mock):
903            names = ('call_args_list', 'method_calls', 'mock_calls')
904            for name in names:
905                attr = getattr(mock, name)
906                self.assertIsInstance(attr, _CallList)
907                self.assertIsInstance(attr, list)
908                self.assertEqual(attr, [])
909
910        assert_attrs(self.mock)
911        with assertNeverAwaited(self):
912            self.mock()
913        with assertNeverAwaited(self):
914            self.mock(1, 2)
915        with assertNeverAwaited(self):
916            self.mock(a=3)
917
918        self.mock.reset_mock()
919        assert_attrs(self.mock)
920
921        a_mock = AsyncMock(AsyncClass)
922        with assertNeverAwaited(self):
923            a_mock.async_method()
924        with assertNeverAwaited(self):
925            a_mock.async_method(1, a=3)
926
927        a_mock.reset_mock()
928        assert_attrs(a_mock)
929
930    def test_assert_awaited(self):
931        with self.assertRaises(AssertionError):
932            self.mock.assert_awaited()
933
934        run(self._runnable_test())
935        self.mock.assert_awaited()
936
937    def test_assert_awaited_once(self):
938        with self.assertRaises(AssertionError):
939            self.mock.assert_awaited_once()
940
941        run(self._runnable_test())
942        self.mock.assert_awaited_once()
943
944        run(self._runnable_test())
945        with self.assertRaises(AssertionError):
946            self.mock.assert_awaited_once()
947
948    def test_assert_awaited_with(self):
949        msg = 'Not awaited'
950        with self.assertRaisesRegex(AssertionError, msg):
951            self.mock.assert_awaited_with('foo')
952
953        run(self._runnable_test())
954        msg = 'expected await not found'
955        with self.assertRaisesRegex(AssertionError, msg):
956            self.mock.assert_awaited_with('foo')
957
958        run(self._runnable_test('foo'))
959        self.mock.assert_awaited_with('foo')
960
961        run(self._runnable_test('SomethingElse'))
962        with self.assertRaises(AssertionError):
963            self.mock.assert_awaited_with('foo')
964
965    def test_assert_awaited_once_with(self):
966        with self.assertRaises(AssertionError):
967            self.mock.assert_awaited_once_with('foo')
968
969        run(self._runnable_test('foo'))
970        self.mock.assert_awaited_once_with('foo')
971
972        run(self._runnable_test('foo'))
973        with self.assertRaises(AssertionError):
974            self.mock.assert_awaited_once_with('foo')
975
976    def test_assert_any_wait(self):
977        with self.assertRaises(AssertionError):
978            self.mock.assert_any_await('foo')
979
980        run(self._runnable_test('baz'))
981        with self.assertRaises(AssertionError):
982            self.mock.assert_any_await('foo')
983
984        run(self._runnable_test('foo'))
985        self.mock.assert_any_await('foo')
986
987        run(self._runnable_test('SomethingElse'))
988        self.mock.assert_any_await('foo')
989
990    def test_assert_has_awaits_no_order(self):
991        calls = [call('foo'), call('baz')]
992
993        with self.assertRaises(AssertionError) as cm:
994            self.mock.assert_has_awaits(calls)
995        self.assertEqual(len(cm.exception.args), 1)
996
997        run(self._runnable_test('foo'))
998        with self.assertRaises(AssertionError):
999            self.mock.assert_has_awaits(calls)
1000
1001        run(self._runnable_test('foo'))
1002        with self.assertRaises(AssertionError):
1003            self.mock.assert_has_awaits(calls)
1004
1005        run(self._runnable_test('baz'))
1006        self.mock.assert_has_awaits(calls)
1007
1008        run(self._runnable_test('SomethingElse'))
1009        self.mock.assert_has_awaits(calls)
1010
1011    def test_awaits_asserts_with_any(self):
1012        class Foo:
1013            def __eq__(self, other): pass
1014
1015        run(self._runnable_test(Foo(), 1))
1016
1017        self.mock.assert_has_awaits([call(ANY, 1)])
1018        self.mock.assert_awaited_with(ANY, 1)
1019        self.mock.assert_any_await(ANY, 1)
1020
1021    def test_awaits_asserts_with_spec_and_any(self):
1022        class Foo:
1023            def __eq__(self, other): pass
1024
1025        mock_with_spec = AsyncMock(spec=Foo)
1026
1027        async def _custom_mock_runnable_test(*args):
1028            await mock_with_spec(*args)
1029
1030        run(_custom_mock_runnable_test(Foo(), 1))
1031        mock_with_spec.assert_has_awaits([call(ANY, 1)])
1032        mock_with_spec.assert_awaited_with(ANY, 1)
1033        mock_with_spec.assert_any_await(ANY, 1)
1034
1035    def test_assert_has_awaits_ordered(self):
1036        calls = [call('foo'), call('baz')]
1037        with self.assertRaises(AssertionError):
1038            self.mock.assert_has_awaits(calls, any_order=True)
1039
1040        run(self._runnable_test('baz'))
1041        with self.assertRaises(AssertionError):
1042            self.mock.assert_has_awaits(calls, any_order=True)
1043
1044        run(self._runnable_test('bamf'))
1045        with self.assertRaises(AssertionError):
1046            self.mock.assert_has_awaits(calls, any_order=True)
1047
1048        run(self._runnable_test('foo'))
1049        self.mock.assert_has_awaits(calls, any_order=True)
1050
1051        run(self._runnable_test('qux'))
1052        self.mock.assert_has_awaits(calls, any_order=True)
1053
1054    def test_assert_not_awaited(self):
1055        self.mock.assert_not_awaited()
1056
1057        run(self._runnable_test())
1058        with self.assertRaises(AssertionError):
1059            self.mock.assert_not_awaited()
1060
1061    def test_assert_has_awaits_not_matching_spec_error(self):
1062        async def f(x=None): pass
1063
1064        self.mock = AsyncMock(spec=f)
1065        run(self._runnable_test(1))
1066
1067        with self.assertRaisesRegex(
1068                AssertionError,
1069                '^{}$'.format(
1070                    re.escape('Awaits not found.\n'
1071                              'Expected: [call()]\n'
1072                              'Actual: [call(1)]'))) as cm:
1073            self.mock.assert_has_awaits([call()])
1074        self.assertIsNone(cm.exception.__cause__)
1075
1076        with self.assertRaisesRegex(
1077                AssertionError,
1078                '^{}$'.format(
1079                    re.escape(
1080                        'Error processing expected awaits.\n'
1081                        "Errors: [None, TypeError('too many positional "
1082                        "arguments')]\n"
1083                        'Expected: [call(), call(1, 2)]\n'
1084                        'Actual: [call(1)]'))) as cm:
1085            self.mock.assert_has_awaits([call(), call(1, 2)])
1086        self.assertIsInstance(cm.exception.__cause__, TypeError)
1087
1088
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
1091