1from test.test_importlib import abc, util
2
3machinery = util.import_importlib('importlib.machinery')
4
5import _imp
6import marshal
7import os.path
8import unittest
9import warnings
10
11from test.support import import_helper, REPO_ROOT, STDLIB_DIR
12
13
def resolve_stdlib_file(name, ispkg=False):
    """Map the dotted module *name* to its source path under STDLIB_DIR.

    Packages (*ispkg* true) map to their ``__init__.py``; plain modules map
    to ``<last component>.py``.  *name* must be non-empty.
    """
    assert name
    parts = name.split('.')
    if not ispkg:
        return os.path.join(STDLIB_DIR, *parts) + '.py'
    return os.path.join(STDLIB_DIR, *parts, '__init__.py')
20
21
class FindSpecTests(abc.FinderTests):

    """Check FrozenImporter.find_spec() against the frozen test modules."""

    def find(self, name, **kwargs):
        """Return FrozenImporter.find_spec(name, **kwargs) with frozen
        modules enabled."""
        importer = self.machinery.FrozenImporter
        with import_helper.frozen_modules():
            found = importer.find_spec(name, **kwargs)
        return found

    def check_basic(self, spec, name, ispkg=False):
        """Assert the attributes that every frozen spec shares.

        Packages (*ispkg* true) must expose submodule_search_locations;
        plain modules must leave it as None.
        """
        self.assertEqual(spec.name, name)
        self.assertIs(spec.loader, self.machinery.FrozenImporter)
        self.assertEqual(spec.origin, 'frozen')
        self.assertFalse(spec.has_location)
        check_locations = self.assertIsNotNone if ispkg else self.assertIsNone
        check_locations(spec.submodule_search_locations)
        self.assertIsNotNone(spec.loader_state)

    def check_loader_state(self, spec, origname=None, filename=None):
        """Assert that spec.loader_state holds exactly origname and filename.

        Without an explicit *filename*, *origname* defaults to spec.name and
        the expected filename is derived from it via resolve_stdlib_file().
        When the effective *origname* is falsy, the expected filename is None
        regardless of the *filename* the caller passed.
        """
        if not filename:
            origname = origname or spec.name
            filename = resolve_stdlib_file(origname)

        expected = {
            'origname': origname,
            # A falsy origname is expected to pair with a None filename.
            'filename': filename if origname else None,
        }
        actual = dict(vars(spec.loader_state))
        self.assertDictEqual(actual, expected)

    def check_search_locations(self, spec):
        """Verify submodule_search_locations; only used for packages."""
        sentinel = object()
        state = spec.loader_state
        filename = getattr(state, 'filename', sentinel)
        origname = getattr(state, 'origname', None)
        if filename is sentinel or not origname:
            # check_loader_state() already reports this situation.
            return
        # Only a package frozen under its own name (or a '<'-prefixed
        # origname) is expected to list its source directory.
        if filename and (origname == spec.name or origname.startswith('<')):
            expected = [os.path.dirname(filename)]
        else:
            expected = []
        self.assertListEqual(spec.submodule_search_locations, expected)

    def test_module(self):
        # Modules frozen under their own names.
        for name in ('__hello__', '__phello__.spam', '__phello__.ham.eggs'):
            with self.subTest(f'{name} -> {name}'):
                spec = self.find(name)
                self.check_basic(spec, name)
                self.check_loader_state(spec)

        # Modules frozen under a name different from their original one.
        aliases = (
            ('__hello_alias__', '__hello__'),
            ('_frozen_importlib', 'importlib._bootstrap'),
        )
        for name, origname in aliases:
            with self.subTest(f'{name} -> {origname}'):
                spec = self.find(name)
                self.check_basic(spec, name)
                self.check_loader_state(spec, origname)

        # Explicit "<pkg>.__init__" modules: the expected origname is the
        # package name prefixed with '<'.
        for name in ('__phello__.__init__', '__phello__.ham.__init__'):
            pkgname, _, _ = name.rpartition('.')
            origname = f'<{pkgname}'
            filename = resolve_stdlib_file(name)
            with self.subTest(f'{name} -> {origname}'):
                spec = self.find(name)
                self.check_basic(spec, name)
                self.check_loader_state(spec, origname, filename)

        # Frozen-only modules; the expected loader_state has no origname.
        frozen_only = (
            ('__hello_only__', ('Tools', 'freeze', 'flag.py')),
        )
        for name, path in frozen_only:
            filename = os.path.join(REPO_ROOT, *path)
            with self.subTest(f'{name} -> {filename}'):
                spec = self.find(name)
                self.check_basic(spec, name)
                self.check_loader_state(spec, None, filename)

    def test_package(self):
        # Packages frozen under their own names.
        for name in ('__phello__', '__phello__.ham'):
            filename = resolve_stdlib_file(name, ispkg=True)
            with self.subTest(f'{name} -> {name}'):
                spec = self.find(name)
                self.check_basic(spec, name, ispkg=True)
                self.check_loader_state(spec, name, filename)
                self.check_search_locations(spec)

        # A package alias whose original is a plain module.
        for name, origname in (('__phello_alias__', '__hello__'),):
            filename = resolve_stdlib_file(origname, ispkg=False)
            with self.subTest(f'{name} -> {origname}'):
                spec = self.find(name)
                self.check_basic(spec, name, ispkg=True)
                self.check_loader_state(spec, origname, filename)
                self.check_search_locations(spec)

    # Nested cases are already exercised by test_module() and test_package().
    test_module_in_package = test_package_in_package = None

    # No easy way to test.
    test_package_over_module = None

    def test_path_ignored(self):
        for name in ('__hello__', '__phello__', '__phello__.spam'):
            expected = self.find(name)
            for path in (None, object(), '', 'eggs', [], [''], ['eggs']):
                with self.subTest((name, path)):
                    self.assertEqual(self.find(name, path=path), expected)

    def test_target_ignored(self):
        with import_helper.CleanImport('__hello__', '__phello__',
                                       usefrozen=True):
            import __hello__ as match
            import __phello__ as nonmatch
        name = '__hello__'
        expected = self.find(name)
        targets = (None, match, nonmatch, object(), 'not-a-module-object')
        for target in targets:
            with self.subTest(target):
                self.assertEqual(self.find(name, target=target), expected)

    def test_failure(self):
        self.assertIsNone(self.find('<not real>'))

    def test_not_using_frozen(self):
        importer = self.machinery.FrozenImporter
        with import_helper.frozen_modules(enabled=False):
            # '__hello__' exists both frozen and as source, while
            # '__hello_only__' is frozen-only; neither may be found here.
            specs = [importer.find_spec(name)
                     for name in ('__hello__', '__hello_only__')]
        for spec in specs:
            self.assertIsNone(spec)
178
179
# util.test_both() presumably yields one FindSpecTests variant per importlib
# implementation (frozen and source), matching the names bound here.
Frozen_FindSpecTests, Source_FindSpecTests = util.test_both(
    FindSpecTests, machinery=machinery)
183
184
class FinderTests(abc.FinderTests):

    """Check the legacy FrozenImporter.find_module() API."""

    def find(self, name, path=None):
        """Return FrozenImporter.find_module(name, path) with frozen
        modules enabled."""
        importer = self.machinery.FrozenImporter
        with warnings.catch_warnings():
            # The find_module() call below is expected to emit a
            # DeprecationWarning; silence it for the duration of the lookup.
            warnings.simplefilter("ignore", DeprecationWarning)
            with import_helper.frozen_modules():
                loader = importer.find_module(name, path)
        return loader

    def _assert_has_load_module(self, loader):
        """Assert that *loader* exposes a load_module attribute."""
        self.assertTrue(hasattr(loader, 'load_module'))

    def test_module(self):
        self._assert_has_load_module(self.find('__hello__'))

    def test_package(self):
        self._assert_has_load_module(self.find('__phello__'))

    def test_module_in_package(self):
        self._assert_has_load_module(
            self.find('__phello__.spam', ['__phello__']))

    # Nested packages are not exercised through the legacy find_module() API.
    test_package_in_package = None

    # No easy way to test.
    test_package_over_module = None

    def test_failure(self):
        self.assertIsNone(self.find('<not real>'))
218
219
# util.test_both() presumably yields one FinderTests variant per importlib
# implementation (frozen and source), matching the names bound here.
Frozen_FinderTests, Source_FinderTests = util.test_both(
    FinderTests, machinery=machinery)
223
224
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
227