1from test.test_importlib import util as test_util
2machinery = test_util.import_importlib('importlib.machinery')
3
4import os
5import re
6import sys
7import unittest
8import warnings
9from test.support import import_helper
10from contextlib import contextmanager
11from test.test_importlib.util import temp_module
12
13import_helper.import_module('winreg', required_on=['win'])
14from winreg import (
15    CreateKey, HKEY_CURRENT_USER,
16    SetValue, REG_SZ, KEY_ALL_ACCESS,
17    EnumKey, CloseKey, DeleteKey, OpenKey
18)
19
def get_platform():
    """Return a distutils-style platform name for the current build.

    Port of distutils.util.get_platform().  A Visual Studio target
    architecture in the VSCMD_ARG_TGT_ARCH environment variable wins when
    it is one of the recognised values; otherwise the architecture is
    sniffed from sys.version, and sys.platform is the final fallback.
    """
    vs_target_platforms = {
        'x86': 'win32',
        'x64': 'win-amd64',
        'arm': 'win-arm32',
    }
    target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
    if target_arch in vs_target_platforms:
        return vs_target_platforms[target_arch]

    version_text = sys.version.lower()
    # Checked in this order; note '(arm)' does not match '(arm64)'.
    for marker, tag in (('amd64', 'win-amd64'),
                        ('(arm)', 'win-arm32'),
                        ('(arm64)', 'win-arm64')):
        if marker in version_text:
            return tag
    return sys.platform
38
def delete_registry_tree(root, subkey):
    """Delete registry key *subkey* of *root*, including all of its subkeys.

    If *subkey* cannot be opened it is treated as absent and nothing
    happens.  Child key names are collected before any deletion.  So a
    child that cannot itself be opened (and is therefore left in place) is
    expected to make the final DeleteKey() raise OSError.  Previously it
    would be re-enumerated at index 0 forever.
    """
    try:
        hkey = OpenKey(root, subkey, access=KEY_ALL_ACCESS)
    except OSError:
        # subkey does not exist (or cannot be opened): nothing to delete
        return
    try:
        # Snapshot the child names first; deleting while enumerating by a
        # fixed index would spin forever on a child that is never removed.
        children = []
        while True:
            try:
                children.append(EnumKey(hkey, len(children)))
            except OSError:
                # no more subkeys
                break
        for child in children:
            delete_registry_tree(hkey, child)
    finally:
        # Release the handle even if enumeration or a recursive delete fails.
        CloseKey(hkey)
    DeleteKey(root, subkey)
54
@contextmanager
def setup_module(machinery, name, path=None):
    """Register module *name* in HKEY_CURRENT_USER for WindowsRegistryFinder.

    A temporary module with source ``a = 1`` is created. The registry key
    that ``machinery.WindowsRegistryFinder`` formats for *name* gets its
    default value set to *path*, or to the temporary module's ``.py`` file
    when *path* is None.  On exit, the registry entries created here are
    removed again.

    machinery -- importlib.machinery module under test.
    name      -- module name to register.
    path      -- value stored in the registry key (default: the temp file).
    """
    if machinery.WindowsRegistryFinder.DEBUG_BUILD:
        root = machinery.WindowsRegistryFinder.REGISTRY_KEY_DEBUG
    else:
        root = machinery.WindowsRegistryFinder.REGISTRY_KEY
    key = root.format(fullname=name,
                      sys_version='%d.%d' % sys.version_info[:2])
    base_key = "Software\\Python\\PythonCore\\{}.{}".format(
        sys.version_info.major, sys.version_info.minor)
    assert key.casefold().startswith(base_key.casefold()), (
        "expected key '{}' to start with '{}'".format(key, base_key))
    # Bind before the try so the finally clause cannot raise
    # UnboundLocalError (masking the real error) if temp_module() fails
    # before we decide what to clean up.
    delete_key = None
    try:
        with temp_module(name, "a = 1") as location:
            try:
                # Existence probe only; close the handle immediately.
                CloseKey(OpenKey(HKEY_CURRENT_USER, base_key))
            except OSError:
                # The per-version tree did not exist, so it is ours to remove.
                delete_key = base_key
            else:
                if machinery.WindowsRegistryFinder.DEBUG_BUILD:
                    # Debug builds remove the key's parent; presumably
                    # REGISTRY_KEY_DEBUG adds a trailing component under the
                    # per-module key -- confirm in _bootstrap_external.
                    delete_key = os.path.dirname(key)
                else:
                    delete_key = key
            if path is None:
                path = location + ".py"
            # Close the created key once its value is written.
            with CreateKey(HKEY_CURRENT_USER, key) as subkey:
                SetValue(subkey, "", REG_SZ, path)
            yield
    finally:
        if delete_key:
            delete_registry_tree(HKEY_CURRENT_USER, delete_key)
85
86
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsRegistryFinderTests:
    """Mixin exercising WindowsRegistryFinder against HKEY_CURRENT_USER.

    Concrete test classes are built by test_util.test_both(), which also
    provides ``self.machinery``.
    """

    # Embedding the PID keeps the module name (and hence its registry key)
    # unique per process, so simultaneous runs on one machine don't clash.
    test_module = "spamham{}".format(os.getpid())

    def _find_module_quietly(self, fullname):
        # find_module() is deprecated; ignore its DeprecationWarning here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            return self.machinery.WindowsRegistryFinder.find_module(fullname)

    def test_find_spec_missing(self):
        finder = self.machinery.WindowsRegistryFinder
        self.assertIsNone(finder.find_spec('spam'))

    def test_find_module_missing(self):
        self.assertIsNone(self._find_module_quietly('spam'))

    def test_module_found(self):
        finder = self.machinery.WindowsRegistryFinder
        with setup_module(self.machinery, self.test_module):
            loader = self._find_module_quietly(self.test_module)
            spec = finder.find_spec(self.test_module)
            self.assertIsNotNone(loader)
            self.assertIsNotNone(spec)

    def test_module_not_found(self):
        finder = self.machinery.WindowsRegistryFinder
        # Registry entry points at "." rather than a real module file.
        with setup_module(self.machinery, self.test_module, path="."):
            loader = self._find_module_quietly(self.test_module)
            spec = finder.find_spec(self.test_module)
            self.assertIsNone(loader)
            self.assertIsNone(spec)
120
# test_both() turns the mixin above into concrete test classes; going by the
# names, one targets frozen importlib and the other its source version.
Frozen_WindowsRegistryFinderTests, Source_WindowsRegistryFinderTests = (
    test_util.test_both(WindowsRegistryFinderTests, machinery=machinery)
)
124
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsExtensionSuffixTests:
    """Mixin checking EXTENSION_SUFFIXES for the interpreter-tagged .pyd.

    ``self.machinery`` is supplied by test_util.test_both().
    """

    def test_tagged_suffix(self):
        suffixes = self.machinery.EXTENSION_SUFFIXES
        # Non-alphanumerics in the platform name become underscores, so
        # e.g. "win-amd64" is embedded as "win_amd64".
        plat = re.sub('[^a-zA-Z0-9]', '_', get_platform())
        expected_tag = ".cp{0.major}{0.minor}-{1}.pyd".format(
            sys.version_info, plat)

        if ".pyd" in suffixes:
            untagged_i = suffixes.index(".pyd")
        else:
            # No plain ".pyd": expect "_d"-prefixed variants instead
            # (presumably a debug build).
            untagged_i = suffixes.index("_d.pyd")
            expected_tag = "_d" + expected_tag

        self.assertIn(expected_tag, suffixes)

        # The tagged suffix must come before (be tried ahead of) the
        # untagged one.
        self.assertLess(suffixes.index(expected_tag), untagged_i)
142
# Concrete test classes generated from the suffix mixin above; by their
# names, one frozen-importlib variant and one source variant.
Frozen_WindowsExtensionSuffixTests, Source_WindowsExtensionSuffixTests = (
    test_util.test_both(WindowsExtensionSuffixTests, machinery=machinery)
)
146
147
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsBootstrapPathTests(unittest.TestCase):
    """Tests for importlib._bootstrap_external._path_join on Windows."""

    def check_join(self, expected, *inputs):
        """Assert _path_join(*inputs) equals *expected*, ignoring case."""
        from importlib._bootstrap_external import _path_join
        actual = _path_join(*inputs)
        # Only defer to assertEqual (for its failure message) when the two
        # differ even after case folding.
        if actual.casefold() != expected.casefold():
            self.assertEqual(expected, actual)

    def test_path_join(self):
        # Each entry is (expected, *inputs).
        cases = (
            # Results rooted at a drive.
            (r"C:\A\B", "C:\\", "A", "B"),
            (r"C:\A\B", "D:\\", "D", "C:\\", "A", "B"),
            (r"C:\A\B", "C:\\", "A", "C:B"),
            (r"C:\A\B", "C:\\", "A\\B"),
            (r"C:\A\B", r"C:\A\B"),

            # Drive-relative results.
            ("D:A", r"D:", "A"),
            ("D:A", r"C:\B\C", "D:", "A"),
            ("D:A", r"C:\B\C", r"D:A"),

            # Relative results.
            (r"A\B\C", "A", "B", "C"),
            (r"A\B\C", "A", r"B\C"),
            (r"A\B/C", "A", "B/C"),
            (r"A\B\C", "A/", "B\\", "C"),

            # Dots are not normalised by this function
            (r"A\../C", "A", "../C"),
            (r"A.\.\B", "A.", ".", "B"),

            # UNC results.
            (r"\\Server\Share\A\B\C", r"\\Server\Share", "A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server\Share", "D", r"\A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server2\Share2", "D",
             r"\\Server\Share", "A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server", r"\Share", "A", "B", "C"),
            (r"\\Server\Share", r"\\Server\Share"),
            (r"\\Server\Share\\", r"\\Server\Share\\"),

            # Handle edge cases with empty segments
            ("C:\\A", "C:/A", ""),
            ("C:\\", "C:/", ""),
            ("C:", "C:", ""),
            ("//Server/Share\\", "//Server/Share/", ""),
            ("//Server/Share\\", "//Server/Share", ""),
        )
        for expected, *inputs in cases:
            self.check_join(expected, *inputs)
191
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
194