# /aosp_15_r20/external/pytorch/test/conftest.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import copy
import functools
import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from types import MethodType
from typing import Any, List, Optional, TYPE_CHECKING, Union

import pytest
from _pytest.config import Config, filename_arg
from _pytest.config.argparsing import Parser
from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML
from _pytest.python import Module
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import _get_raw_skip_reason

from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin


if TYPE_CHECKING:
    from _pytest._code.code import ReprFileLocation

# Much of this file is copied from _pytest.junitxml and modified to record rerun info.

xml_key = StashKey["LogXMLReruns"]()
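# Key under pytest's cache (typically materialized as .pytest_cache/v/ on disk)
# below which the stepcurrent plugin stores the nodeid of the most recently
# started test.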
STEPCURRENT_CACHE_DIR = "cache/stepcurrent"


def pytest_addoption(parser: Parser) -> None:
    group = parser.getgroup("general")
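    # Short flags backing StepcurrentPlugin (see below):
    #   --scs <key>: resume after the last-run test recorded under <key>, skipping it
    #   --sc <key>:  resume at the last-run test recorded under <key>
    #   --rs <key>:  resume at the last-run test and run only that single test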
    group.addoption("--scs", action="store", default=None, dest="stepcurrent_skip")
    group.addoption("--sc", action="store", default=None, dest="stepcurrent")
    group.addoption("--rs", action="store", default=None, dest="run_single")

    parser.addoption("--use-main-module", action="store_true")
    group = parser.getgroup("terminal reporting")
    group.addoption(
        "--junit-xml-reruns",
        action="store",
        dest="xmlpath_reruns",
        metavar="path",
        type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
        default=None,
        help="create junit-xml style report file at given path.",
    )
    group.addoption(
        "--junit-prefix-reruns",
        action="store",
        metavar="str",
        default=None,
        help="prepend prefix to classnames in junit-xml output",
    )
    parser.addini(
        "junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
    )
    parser.addini(
        "junit_logging_reruns",
        "Write captured log messages to JUnit report: "
        "one of no|log|system-out|system-err|out-err|all",
        default="no",
    )
    parser.addini(
        "junit_log_passing_tests_reruns",
        "Capture log information for passing tests to JUnit report",
        type="bool",
        default=True,
    )
    parser.addini(
        "junit_duration_report_reruns",
        "Duration time to report: one of total|call",
        default="total",
    )
    parser.addini(
        "junit_family_reruns",
        "Emit XML for schema: one of legacy|xunit1|xunit2",
        default="xunit2",
    )
    shard_addoptions(parser)


def pytest_configure(config: Config) -> None:
    xmlpath = config.option.xmlpath_reruns
    # Prevent opening xmllog on worker nodes (xdist).
    if xmlpath and not hasattr(config, "workerinput"):
        junit_family = config.getini("junit_family_reruns")
        config.stash[xml_key] = LogXMLReruns(
            xmlpath,
            config.option.junitprefix,
            config.getini("junit_suite_name_reruns"),
            config.getini("junit_logging_reruns"),
            config.getini("junit_duration_report_reruns"),
            junit_family,
            config.getini("junit_log_passing_tests_reruns"),
        )
        config.pluginmanager.register(config.stash[xml_key])
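    # --scs and --rs are variants of --sc: all three share the same stepcurrent
    # cache slot, and StepcurrentPlugin reads its skip/run-single behavior from
    # the individual option values.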
    if config.getoption("stepcurrent_skip"):
        config.option.stepcurrent = config.getoption("stepcurrent_skip")
    if config.getoption("run_single"):
        config.option.stepcurrent = config.getoption("run_single")
    if config.getoption("stepcurrent"):
        config.pluginmanager.register(StepcurrentPlugin(config), "stepcurrentplugin")
    if config.getoption("num_shards"):
        config.pluginmanager.register(PytestShardPlugin(config), "pytestshardplugin")


def pytest_unconfigure(config: Config) -> None:
    xml = config.stash.get(xml_key, None)
    if xml:
        del config.stash[xml_key]
        config.pluginmanager.unregister(xml)


class _NodeReporterReruns(_NodeReporter):
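    """Node reporter that, unlike the base class, writes captured content
    without the centered "-----" banner headers and escapes XML-unsafe
    characters in skip reasons."""
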
    def _prepare_content(self, content: str, header: str) -> str:
        return content

    def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
        if content == "":
            return
        tag = ET.Element(jheader)
        tag.text = bin_xml_escape(content)
        self.append(tag)

    def append_skipped(self, report: TestReport) -> None:
        # Adapted from
        # https://github.com/pytest-dev/pytest/blob/2178ee86d7c1ee93748cfb46540a6e40b4761f2d/src/_pytest/junitxml.py#L236C6-L236C6
        # and modified to escape characters not supported by XML in the skip reason.
        # Everything else should be the same.
        if hasattr(report, "wasxfail"):
            # Call super() here instead of copying the code to reduce possible divergence.
            super().append_skipped(report)
        else:
            assert isinstance(report.longrepr, tuple)
            filename, lineno, skipreason = report.longrepr
            if skipreason.startswith("Skipped: "):
                skipreason = skipreason[9:]
            details = f"{filename}:{lineno}: {skipreason}"

            skipped = ET.Element(
                "skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
            )
            skipped.text = bin_xml_escape(details)
            self.append(skipped)
            self.write_captured_output(report)


class LogXMLReruns(LogXML):
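    """LogXML variant that additionally records tests with outcome "rerun"
    (as produced by plugins like pytest-rerunfailures) in the JUnit XML."""
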
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
        else:
            assert report.longrepr is not None
            reprcrash: Optional[ReprFileLocation] = getattr(
                report.longrepr, "reprcrash", None
            )
            if reprcrash is not None:
                message = reprcrash.message
            else:
                message = str(report.longrepr)
            message = bin_xml_escape(message)
            reporter._add_simple("rerun", message, str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        super().pytest_runtest_logreport(report)
        if report.outcome == "rerun":
            reporter = self._opentestcase(report)
            self.append_rerun(reporter, report)
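        # For skips, rewrite the reason to include the test's nodeid so the
        # XML makes clear which test was skipped.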
        if report.outcome == "skipped":
            if isinstance(report.longrepr, tuple):
                fspath, lineno, _ = report.longrepr
                reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
                report.longrepr = (fspath, lineno, reason)

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        # Local hack to handle xdist report order.
        workernode = getattr(report, "node", None)

        key = nodeid, workernode

        if key in self.node_reporters:
            # TODO: breaks for --dist=each
            return self.node_reporters[key]

        reporter = _NodeReporterReruns(nodeid, self)

        self.node_reporters[key] = reporter
        self.node_reporters_ordered.append(reporter)

        return reporter


# Imitating summary_failures in pytest's terminal.py.
# Both hookwrapper and tryfirst, to make sure this runs before pytest's own summary.
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    # Print stack traces for reruns.
    if terminalreporter.config.option.tbstyle != "no":
        reports = terminalreporter.getreports("rerun")
        if reports:
            terminalreporter.write_sep("=", "RERUNS")
            if terminalreporter.config.option.tbstyle == "line":
                for rep in reports:
                    line = terminalreporter._getcrashline(rep)
                    terminalreporter.write_line(line)
            else:
                for rep in reports:
                    msg = terminalreporter._getfailureheadline(rep)
                    terminalreporter.write_sep("_", msg, red=True, bold=True)
                    terminalreporter._outrep_summary(rep)
                    terminalreporter._handle_teardown_sections(rep.nodeid)
    yield


@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makemodule(module_path, path, parent) -> Module:
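    """With --use-main-module, back the collected Module with
    sys.modules["__main__"] instead of importing the test file again."""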
    if parent.config.getoption("--use-main-module"):
        mod = Module.from_parent(parent, path=module_path)
        mod._getobj = MethodType(lambda x: sys.modules["__main__"], mod)
        return mod


@pytest.hookimpl(hookwrapper=True)
def pytest_report_teststatus(report, config):
    # Add the test time to the verbose output; unfortunately this does not
    # appear to include setup or teardown time.
    pluggy_result = yield
    if not isinstance(report, pytest.TestReport):
        return
    outcome, letter, verbose = pluggy_result.get_result()
    if verbose:
        pluggy_result.force_result(
            (outcome, letter, f"{verbose} [{report.duration:.4f}s]")
        )


@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(items: List[Any]) -> None:
    """
    This hook is used when rerunning disabled tests to get rid of all skipped tests
    instead of running and skipping them N times. This avoids flooding the console
    and XML outputs with junk, so we want this hook to run last when collecting tests.
    """
    rerun_disabled_tests = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
    if not rerun_disabled_tests:
        return

    disabled_regex = re.compile(r"(?P<test_name>.+)\s+\([^\.]+\.(?P<test_class>.+)\)")
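    # Matches entries like "test_add (__main__.TestFoo)" (illustrative name),
    # yielding test_name="test_add" and test_class="TestFoo".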
    disabled_tests = defaultdict(set)

    # This environment variable has already been set by run_test before it calls pytest.
    disabled_tests_file = os.getenv("DISABLED_TESTS_FILE", "")
    if not disabled_tests_file or not os.path.exists(disabled_tests_file):
        return

    with open(disabled_tests_file) as fp:
        for disabled_test in json.load(fp):
            m = disabled_regex.match(disabled_test)
            if m:
                test_name = m["test_name"]
                test_class = m["test_class"]
                disabled_tests[test_class].add(test_name)

    # When rerunning disabled tests, ignore all test cases that are not disabled.
    filtered_items = []

    for item in items:
        test_name = item.name
        test_class = item.parent.name

        if (
            test_class not in disabled_tests
            or test_name not in disabled_tests[test_class]
        ):
            continue

        # Copy the item and reinitialize its fixture request so it can be run again.
        cpy = copy.copy(item)
        cpy._initrequest()

        filtered_items.append(cpy)

    # NB: Need to edit items in place for the changes to be reflected back to pytest.
    items.clear()
    items.extend(filtered_items)


class StepcurrentPlugin:
    # Modified from _pytest/stepwise.py in order to save the currently running
    # test instead of the last failed test.
    def __init__(self, config: Config) -> None:
        self.config = config
        self.report_status = ""
        assert config.cache is not None
        self.cache: pytest.Cache = config.cache
        self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
        self.lastrun: Optional[str] = self.cache.get(self.directory, None)
        self.initial_val = self.lastrun
        self.skip: bool = config.getoption("stepcurrent_skip")
        self.run_single: bool = config.getoption("run_single")

    def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
        if not self.lastrun:
            self.report_status = "Cannot find last run test, not skipping"
            return

        # Check all item nodes until we find a match on the last-run test.
        failed_index = None
        for index, item in enumerate(items):
            if item.nodeid == self.lastrun:
                failed_index = index
                if self.skip:
                    failed_index += 1
                break

        # If the previously run test was not found among the test items,
        # do not skip any tests.
        if failed_index is None:
            self.report_status = "previously run test not found, not skipping."
        else:
            self.report_status = f"skipping {failed_index} already run items."
            deselected = items[:failed_index]
            del items[:failed_index]
            if self.run_single:
                self.report_status += f" Running only {items[0].nodeid}"
                deselected += items[1:]
                del items[1:]
            config.hook.pytest_deselected(items=deselected)

    def pytest_report_collectionfinish(self) -> Optional[str]:
        if self.config.getoption("verbose") >= 0 and self.report_status:
            return f"stepcurrent: {self.report_status}"
        return None

    def pytest_runtest_protocol(self, item, nextitem) -> None:
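        # Record the nodeid of the test that is about to run; if the process
        # crashes mid-test, the cache entry points at the offending test.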
        self.lastrun = item.nodeid
        self.cache.set(self.directory, self.lastrun)

    def pytest_sessionfinish(self, session, exitstatus):
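        # On a clean exit (and when not running a single test), restore the
        # cache entry to the value it had at session start.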
        if exitstatus == 0 and not self.run_single:
            self.cache.set(self.directory, self.initial_val)
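

# Illustrative (hypothetical) invocations of the stepcurrent options, assuming
# pytest picks up this conftest.py:
#
#   pytest test_foo.py --sc my_run    # resume at the last-run test under key "my_run"
#   pytest test_foo.py --scs my_run   # resume after it, skipping the (e.g. crashed) test
#   pytest test_foo.py --rs my_run    # run only that single last-run test
#
# "test_foo.py" and "my_run" are made-up names; the key selects the slot under
# cache/stepcurrent where the last-run nodeid is saved.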
347