1import unittest
2from test import support
3
4
class TestMROEntry(unittest.TestCase):
    """Tests for the ``__mro_entries__`` hook (PEP 560).

    When a non-class object appears among the bases of a ``class``
    statement, its ``__mro_entries__`` method is called with the tuple of
    bases as written and returns a tuple of classes to use in its place.
    The resulting class keeps the bases as written in ``__orig_bases__``.
    """

    def test_mro_entry_signature(self):
        # The hook receives exactly one positional argument (the complete
        # tuple of original bases) and no keyword arguments.
        tested = []
        class B: ...
        class C:
            def __mro_entries__(self, *args, **kwargs):
                tested.extend([args, kwargs])
                return (C,)
        c = C()
        self.assertEqual(tested, [])
        class D(B, c): ...
        self.assertEqual(tested[0], ((B, c),))
        self.assertEqual(tested[1], {})

    def test_mro_entry(self):
        # The hook's result is substituted into __bases__ and __mro__, while
        # __orig_bases__ preserves the bases exactly as written.
        tested = []
        class A: ...
        class B: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (self.__class__,)
        c = C()
        self.assertEqual(tested, [])
        class D(A, c, B): ...
        self.assertEqual(tested[-1], (A, c, B))
        self.assertEqual(D.__bases__, (A, C, B))
        self.assertEqual(D.__orig_bases__, (A, c, B))
        self.assertEqual(D.__mro__, (D, A, C, B, object))
        # D inherits the hook, so an instance of D is also usable as a base
        # and resolves to its own class.
        d = D()
        class E(d): ...
        self.assertEqual(tested[-1], (d,))
        self.assertEqual(E.__bases__, (D,))

    def test_mro_entry_none(self):
        # Returning an empty tuple drops the entry; a class left with no
        # bases at all ends up deriving from object.
        tested = []
        class A: ...
        class B: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return ()
        c = C()
        self.assertEqual(tested, [])
        class D(A, c, B): ...
        self.assertEqual(tested[-1], (A, c, B))
        self.assertEqual(D.__bases__, (A, B))
        self.assertEqual(D.__orig_bases__, (A, c, B))
        self.assertEqual(D.__mro__, (D, A, B, object))
        class E(c): ...
        self.assertEqual(tested[-1], (c,))
        self.assertEqual(E.__bases__, (object,))
        self.assertEqual(E.__orig_bases__, (c,))
        self.assertEqual(E.__mro__, (E, object))

    def test_mro_entry_with_builtins(self):
        # The hook may substitute a builtin type.
        tested = []
        class A: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (dict,)
        c = C()
        self.assertEqual(tested, [])
        class D(A, c): ...
        self.assertEqual(tested[-1], (A, c))
        self.assertEqual(D.__bases__, (A, dict))
        self.assertEqual(D.__orig_bases__, (A, c))
        self.assertEqual(D.__mro__, (D, A, dict, object))

    def test_mro_entry_with_builtins_2(self):
        # A hook-resolved base can be combined with a builtin base.
        tested = []
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (C,)
        c = C()
        self.assertEqual(tested, [])
        class D(c, dict): ...
        self.assertEqual(tested[-1], (c, dict))
        self.assertEqual(D.__bases__, (C, dict))
        self.assertEqual(D.__orig_bases__, (c, dict))
        self.assertEqual(D.__mro__, (D, C, dict, object))

    def test_mro_entry_errors(self):
        # A hook whose signature does not accept exactly one argument
        # raises TypeError when the class statement runs.
        class C_too_many:
            def __mro_entries__(self, bases, something, other):
                return ()
        c = C_too_many()
        with self.assertRaises(TypeError):
            class D(c): ...
        class C_too_few:
            def __mro_entries__(self):
                return ()
        d = C_too_few()
        with self.assertRaises(TypeError):
            class D(d): ...

    def test_mro_entry_errors_2(self):
        # A non-callable hook raises TypeError.
        class C_not_callable:
            __mro_entries__ = "Surprise!"
        c = C_not_callable()
        with self.assertRaises(TypeError):
            class D(c): ...
        # A hook returning something other than a tuple raises TypeError.
        # The hook must take the ``bases`` argument; otherwise the error
        # would come from the arity mismatch (already covered by
        # test_mro_entry_errors) and the return-type check would go untested.
        called = []
        class C_not_tuple:
            def __mro_entries__(self, bases):
                called.append(bases)
                return object
        c = C_not_tuple()
        with self.assertRaises(TypeError):
            class D(c): ...
        # Prove the hook ran, so the TypeError came from its result.
        self.assertEqual(called, [(c,)])

    def test_mro_entry_metaclass(self):
        # The metaclass receives the resolved bases, not the originals.
        meta_args = []
        class Meta(type):
            def __new__(mcls, name, bases, ns):
                meta_args.extend([mcls, name, bases, ns])
                return super().__new__(mcls, name, bases, ns)
        class A: ...
        class C:
            def __mro_entries__(self, bases):
                return (A,)
        c = C()
        class D(c, metaclass=Meta):
            x = 1
        self.assertEqual(meta_args[0], Meta)
        self.assertEqual(meta_args[1], 'D')
        self.assertEqual(meta_args[2], (A,))
        self.assertEqual(meta_args[3]['x'], 1)
        self.assertEqual(D.__bases__, (A,))
        self.assertEqual(D.__orig_bases__, (c,))
        self.assertEqual(D.__mro__, (D, A, object))
        self.assertEqual(D.__class__, Meta)

    def test_mro_entry_type_call(self):
        # Substitution should _not_ happen in direct type call
        class C:
            def __mro_entries__(self, bases):
                return ()
        c = C()
        # Escape the regex metacharacters so the literal
        # "types.new_class()" text is what gets matched.
        with self.assertRaisesRegex(TypeError,
                                    r"MRO entry resolution; "
                                    r"use types\.new_class\(\)"):
            type('Bad', (c,), {})
148
149
150class TestClassGetitem(unittest.TestCase):
151    def test_class_getitem(self):
152        getitem_args = []
153        class C:
154            def __class_getitem__(*args, **kwargs):
155                getitem_args.extend([args, kwargs])
156                return None
157        C[int, str]
158        self.assertEqual(getitem_args[0], (C, (int, str)))
159        self.assertEqual(getitem_args[1], {})
160
161    def test_class_getitem_format(self):
162        class C:
163            def __class_getitem__(cls, item):
164                return f'C[{item.__name__}]'
165        self.assertEqual(C[int], 'C[int]')
166        self.assertEqual(C[C], 'C[C]')
167
168    def test_class_getitem_inheritance(self):
169        class C:
170            def __class_getitem__(cls, item):
171                return f'{cls.__name__}[{item.__name__}]'
172        class D(C): ...
173        self.assertEqual(D[int], 'D[int]')
174        self.assertEqual(D[D], 'D[D]')
175
176    def test_class_getitem_inheritance_2(self):
177        class C:
178            def __class_getitem__(cls, item):
179                return 'Should not see this'
180        class D(C):
181            def __class_getitem__(cls, item):
182                return f'{cls.__name__}[{item.__name__}]'
183        self.assertEqual(D[int], 'D[int]')
184        self.assertEqual(D[D], 'D[D]')
185
186    def test_class_getitem_classmethod(self):
187        class C:
188            @classmethod
189            def __class_getitem__(cls, item):
190                return f'{cls.__name__}[{item.__name__}]'
191        class D(C): ...
192        self.assertEqual(D[int], 'D[int]')
193        self.assertEqual(D[D], 'D[D]')
194
195    def test_class_getitem_patched(self):
196        class C:
197            def __init_subclass__(cls):
198                def __class_getitem__(cls, item):
199                    return f'{cls.__name__}[{item.__name__}]'
200                cls.__class_getitem__ = classmethod(__class_getitem__)
201        class D(C): ...
202        self.assertEqual(D[int], 'D[int]')
203        self.assertEqual(D[D], 'D[D]')
204
205    def test_class_getitem_with_builtins(self):
206        class A(dict):
207            called_with = None
208
209            def __class_getitem__(cls, item):
210                cls.called_with = item
211        class B(A):
212            pass
213        self.assertIs(B.called_with, None)
214        B[int]
215        self.assertIs(B.called_with, int)
216
217    def test_class_getitem_errors(self):
218        class C_too_few:
219            def __class_getitem__(cls):
220                return None
221        with self.assertRaises(TypeError):
222            C_too_few[int]
223
224        class C_too_many:
225            def __class_getitem__(cls, one, two):
226                return None
227        with self.assertRaises(TypeError):
228            C_too_many[int]
229
230    def test_class_getitem_errors_2(self):
231        class C:
232            def __class_getitem__(cls, item):
233                return None
234        with self.assertRaises(TypeError):
235            C()[int]
236
237        class E: ...
238        e = E()
239        e.__class_getitem__ = lambda cls, item: 'This will not work'
240        with self.assertRaises(TypeError):
241            e[int]
242
243        class C_not_callable:
244            __class_getitem__ = "Surprise!"
245        with self.assertRaises(TypeError):
246            C_not_callable[int]
247
248        class C_is_none(tuple):
249            __class_getitem__ = None
250        with self.assertRaisesRegex(TypeError, "C_is_none"):
251            C_is_none[int]
252
253    def test_class_getitem_metaclass(self):
254        class Meta(type):
255            def __class_getitem__(cls, item):
256                return f'{cls.__name__}[{item.__name__}]'
257        self.assertEqual(Meta[int], 'Meta[int]')
258
259    def test_class_getitem_with_metaclass(self):
260        class Meta(type): pass
261        class C(metaclass=Meta):
262            def __class_getitem__(cls, item):
263                return f'{cls.__name__}[{item.__name__}]'
264        self.assertEqual(C[int], 'C[int]')
265
266    def test_class_getitem_metaclass_first(self):
267        class Meta(type):
268            def __getitem__(cls, item):
269                return 'from metaclass'
270        class C(metaclass=Meta):
271            def __class_getitem__(cls, item):
272                return 'from __class_getitem__'
273        self.assertEqual(C[int], 'from metaclass')
274
275
@support.cpython_only
class CAPITest(unittest.TestCase):
    """PEP 560 hooks provided by C-level types from ``_testcapi``."""

    def test_c_class(self):
        from _testcapi import Generic, GenericAlias

        # Calling the hook directly yields a GenericAlias instance.
        direct = Generic.__class_getitem__(int)
        self.assertIsInstance(direct, GenericAlias)

        # So does subscripting, and the alias resolves to its argument
        # when it is used as a base class.
        alias = Generic[int]
        self.assertIs(type(alias), GenericAlias)
        self.assertEqual(alias.__mro_entries__(()), (int,))

        class C(alias):
            pass
        self.assertEqual(C.__bases__, (int,))
        self.assertEqual(C.__orig_bases__, (alias,))
        self.assertEqual(C.__mro__, (C, int, object))
291
292
if __name__ == "__main__":
    # Running the file directly hands control to the unittest CLI runner.
    unittest.main()
295