1import sys
2import unittest
3from doctest import DocTestSuite
4from test import support
5from test.support import threading_helper
6from test.support.import_helper import import_module
7import weakref
8import gc
9
10# Modules under test
11import _thread
12import threading
13import _threading_local
14
15
# Skip this whole module at import time (module=True) on platforms that
# cannot start threads; most of the tests below rely on starting threads.
threading_helper.requires_working_threading(module=True)
17
18
class Weak:
    """Trivial weak-referenceable object stored in thread-locals by the tests."""
21
def target(local, weaklist):
    """Thread body: store a fresh Weak on *local*, remember a weakref to it.

    The weak reference lets the caller later observe whether the value
    stored in this thread's slot of *local* has been freed.
    """
    instance = Weak()
    local.weak = instance
    ref_to_instance = weakref.ref(instance)
    weaklist.append(ref_to_instance)
26
27
class BaseLocalTest:
    """Implementation-agnostic tests for thread-local objects.

    This is a mixin: concrete subclasses also derive from unittest.TestCase
    and set the ``_local`` class attribute to the thread-local type under
    test.
    """

    def test_local_refs(self):
        self._local_refs(20)
        self._local_refs(50)
        self._local_refs(100)

    def _local_refs(self, n):
        """Check that per-thread values are released once their thread ends.

        Runs *n* threads one after another; each stores a fresh ``Weak``
        object on the shared local and records a weak reference to it.
        """
        local = self._local()
        weaklist = []
        for _ in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
            # Drop the Thread object inside the loop so that n == 0 does not
            # hit a NameError on an unbound ``t``.
            del t

        support.gc_collect()  # For PyPy or other GCs.
        self.assertEqual(len(weaklist), n)

        # XXX _threading_local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n))

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        support.gc_collect()  # For PyPy or other GCs.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(self._local):
            def __init__(self):
                time.sleep(0.01)
        local = Local()
        mismatches = []

        def f(i):
            local.x = i
            # Simply check that the variable is correctly set.  Failures are
            # recorded rather than asserted here, because an assertion raised
            # in a worker thread does not fail the test itself.
            value = getattr(local, 'x', None)
            if value != i:
                mismatches.append((i, value))

        with threading_helper.start_threads(threading.Thread(target=f, args=(i,))
                                            for i in range(10)):
            pass
        self.assertEqual(mismatches, [])

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(self._local):
            pass
        new_locals = None
        passed = False
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            nonlocal passed
            try:
                # 1) Involve Local in a cycle
                cycle = [Local()]
                cycle.append(cycle)
                cycle[0].foo = 'bar'

                # 2) GC the cycle (triggers threadmodule.c::local_clear
                # before local_dealloc)
                del cycle
                support.gc_collect()  # For PyPy or other GCs.
            finally:
                # Always release the main thread, even if the setup above
                # raises; otherwise its e1.wait() would block forever.
                e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed = all(not hasattr(local, 'foo') for local in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()
        try:
            # 3) New Locals should recycle the original's address. Creating
            # them in the thread overwrites the thread state and avoids the
            # bug
            new_locals = [Local() for _ in range(10)]
        finally:
            # Never leave the worker blocked on e2, or join() would hang.
            e2.set()
            t.join()

        self.assertTrue(passed)

    def test_arguments(self):
        # Issue 1522237: a subclass that defines __init__ may accept
        # constructor arguments, while the base type itself rejects them.
        class MyLocal(self._local):
            def __init__(self, *args, **kwargs):
                pass

        MyLocal(a=1)
        MyLocal(1)
        self.assertRaises(TypeError, self._local, a=1)
        self.assertRaises(TypeError, self._local, 1)

    def _test_one_class(self, c):
        """Check that attributes set in one thread are invisible in another.

        Thread t1 sets attributes on a shared instance of *c* and stays
        blocked on an event while thread t2 checks that ``obj.x`` raises
        AttributeError from its own view of the object.
        """
        self._failed = "No error message set or cleared."
        obj = c()
        e1 = threading.Event()
        e2 = threading.Event()

        def f1():
            try:
                obj.x = 'foo'
                obj.y = 'bar'
                del obj.y
            finally:
                # Release the main thread even if the attribute operations
                # above raise, so a broken implementation cannot hang it.
                e1.set()
            e2.wait()

        def f2():
            try:
                foo = obj.x
            except AttributeError:
                # This is expected -- we haven't set obj.x in this thread yet!
                self._failed = ""  # passed
            else:
                self._failed = ('Incorrectly got value %r from class %r\n' %
                                (foo, c))
                sys.stderr.write(self._failed)

        t1 = threading.Thread(target=f1)
        t1.start()
        e1.wait()
        t2 = threading.Thread(target=f2)
        t2.start()
        t2.join()
        # The test is done; just let t1 know it can exit, and wait for it.
        e2.set()
        t1.join()

        self.assertFalse(self._failed, self._failed)

    def test_threading_local(self):
        self._test_one_class(self._local)

    def test_threading_local_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_one_class(LocalSubclass)

    def _test_dict_attribute(self, cls):
        """Check that ``__dict__`` reflects set attributes and is read-only."""
        obj = cls()
        obj.x = 5
        self.assertEqual(obj.__dict__, {'x': 5})
        with self.assertRaises(AttributeError):
            obj.__dict__ = {}
        with self.assertRaises(AttributeError):
            del obj.__dict__

    def test_dict_attribute(self):
        self._test_dict_attribute(self._local)

    def test_dict_attribute_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_dict_attribute(LocalSubclass)

    def test_cycle_collection(self):
        # A reference cycle running through a local must still be collected.
        class X:
            pass

        x = X()
        x.local = self._local()
        x.local.x = x
        wr = weakref.ref(x)
        del x
        support.gc_collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())

    def test_threading_local_clear_race(self):
        # See https://github.com/python/cpython/issues/100892
        # This deliberately uses threading.local rather than self._local, so
        # both concrete subclasses run the same check. import_module() skips
        # the test when _testcapi is unavailable.
        _testcapi = import_module('_testcapi')
        _testcapi.call_in_temporary_c_thread(lambda: None, False)

        for _ in range(1000):
            _ = threading.local()

        _testcapi.join_temporary_c_thread()
211
212
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared BaseLocalTest checks against ``_thread._local``."""
    _local = _thread._local
215
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    """Run the shared BaseLocalTest checks against ``_threading_local.local``."""
    _local = _threading_local.local
218
219
def load_tests(loader, tests, pattern):
    """unittest ``load_tests`` hook: add the ``_threading_local`` doctests.

    The doctests are added twice: once as written, and once with ``local``
    replaced by ``_thread._local`` so the same examples exercise the C
    implementation.

    Returns *tests* with both doctest suites added.
    """
    tests.addTest(DocTestSuite('_threading_local'))

    local_orig = _threading_local.local

    def setUp(test):
        # DocTestSuite copies the module's globals into test.globs when the
        # suite is built, so rebinding the module attribute alone does not
        # change the ``local`` that the examples see; patch the globs too.
        test.globs['local'] = _thread._local
        _threading_local.local = _thread._local

    def tearDown(test):
        test.globs['local'] = local_orig
        _threading_local.local = local_orig

    tests.addTests(DocTestSuite('_threading_local',
                                setUp=setUp, tearDown=tearDown))
    return tests
232
233
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
236