xref: /aosp_15_r20/external/bazelbuild-rules_python/tests/entry_points/py_console_script_gen_test.py (revision 60517a1edbc8ecf509223e9af94a7adec7d736b8)
1#!/usr/bin/env python3
2# Copyright 2023 The Bazel Authors. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import pathlib
17import tempfile
18import textwrap
19import unittest
20
21from python.private.py_console_script_gen import run
22
23
class RunTest(unittest.TestCase):
    """Tests for ``run`` from ``python.private.py_console_script_gen``.

    Each test writes a small ``entry_points.txt`` into a temporary directory,
    calls ``run`` on it, and then checks either the message of the raised
    ``RuntimeError`` or the contents of the generated output file.
    """

    def setUp(self):
        # Do not truncate diffs when the generated script mismatches.
        self.maxDiff = None

    def _write_entry_points(self, tmpdir, contents):
        """Write ``contents`` to ``tmpdir / "entry_points.txt"`` and return the path.

        ``contents`` is dedented and stripped of surrounding whitespace, and a
        single trailing newline is appended, so callers can pass an indented
        triple-quoted literal.
        """
        entry_points = tmpdir / "entry_points.txt"
        entry_points.write_text(textwrap.dedent(contents).strip() + "\n")
        return entry_points

    def test_no_console_scripts_error(self):
        """A file without a ``[console_scripts]`` section raises RuntimeError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            outfile = tmpdir / "out.py"
            entry_points = self._write_entry_points(
                tmpdir,
                """
                [non_console_scripts]
                foo = foo.bar:fizz
                """,
            )

            with self.assertRaises(RuntimeError) as cm:
                run(
                    entry_points=entry_points,
                    out=outfile,
                    console_script=None,
                    console_script_guess="",
                )

        self.assertEqual(
            "The package does not provide any console_scripts in its entry_points.txt",
            cm.exception.args[0],
        )

    def test_no_entry_point_selected_error(self):
        """A guess that matches no console script raises RuntimeError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            outfile = tmpdir / "out.py"
            entry_points = self._write_entry_points(
                tmpdir,
                """
                [console_scripts]
                foo = foo.bar:fizz
                """,
            )

            with self.assertRaises(RuntimeError) as cm:
                run(
                    entry_points=entry_points,
                    out=outfile,
                    console_script=None,
                    console_script_guess="bar-baz",
                )

        self.assertEqual(
            "Tried to guess that you wanted 'bar-baz', but could not find it. Please select one of the following console scripts: foo",
            cm.exception.args[0],
        )

    def test_incorrect_entry_point(self):
        """An explicitly requested but missing console script raises RuntimeError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            outfile = tmpdir / "out.py"
            entry_points = self._write_entry_points(
                tmpdir,
                """
                [console_scripts]
                foo = foo.bar:fizz
                bar = foo.bar:buzz
                """,
            )

            with self.assertRaises(RuntimeError) as cm:
                run(
                    entry_points=entry_points,
                    out=outfile,
                    console_script="baz",
                    console_script_guess="",
                )

        self.assertEqual(
            "The console_script 'baz' was not found, only the following are available: bar, foo",
            cm.exception.args[0],
        )

    def test_a_single_entry_point(self):
        """A matching guess produces the expected generated script verbatim."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            entry_points = self._write_entry_points(
                tmpdir,
                """
                [console_scripts]
                foo = foo.bar:baz
                """,
            )
            out = tmpdir / "foo.py"

            run(
                entry_points=entry_points,
                out=out,
                console_script=None,
                console_script_guess="foo",
            )

            # Read the result before the temporary directory is removed.
            got = out.read_text()

        want = textwrap.dedent(
            """\
        import sys

        # See @rules_python//python/private:py_console_script_gen.py for explanation
        if getattr(sys.flags, "safe_path", False):
            # We are running on Python 3.11 and we don't need this workaround
            pass
        elif ".runfiles" not in sys.path[0]:
            sys.path = sys.path[1:]

        try:
            from foo.bar import baz
        except ImportError:
            entries = "\\n".join(sys.path)
            print("Printing sys.path entries for easier debugging:", file=sys.stderr)
            print(f"sys.path is:\\n{entries}", file=sys.stderr)
            raise

        if __name__ == "__main__":
            sys.exit(baz())
        """
        )
        self.assertEqual(want, got)

    def test_a_second_entry_point_class_method(self):
        """Selecting ``bar`` imports the class and calls its method on exit."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            entry_points = self._write_entry_points(
                tmpdir,
                """
                [console_scripts]
                foo = foo.bar:Bar.baz
                bar = foo.baz:Bar.baz
                """,
            )
            out = tmpdir / "out.py"

            run(
                entry_points=entry_points,
                out=out,
                console_script="bar",
                console_script_guess="",
            )

            # Read the result before the temporary directory is removed.
            got = out.read_text()

        # Raw strings: "\." and "\(" are invalid escapes in a normal literal
        # and trigger a DeprecationWarning/SyntaxWarning on newer Pythons.
        self.assertRegex(got, r"from foo\.baz import Bar")
        self.assertRegex(got, r"sys\.exit\(Bar\.baz\(\)\)")
194
195
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
198