import copy
import functools
import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from types import MethodType
from typing import Any, List, Optional, TYPE_CHECKING, Union

import pytest
from _pytest.config import Config, filename_arg
from _pytest.config.argparsing import Parser
from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML
from _pytest.python import Module
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import _get_raw_skip_reason

from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin


if TYPE_CHECKING:
    from _pytest._code.code import ReprFileLocation

# A lot of this file is copied from _pytest.junitxml and modified to get rerun info.

# Stash slot holding the LogXMLReruns instance created in pytest_configure.
xml_key = StashKey["LogXMLReruns"]()
# Cache namespace under which StepcurrentPlugin records the most recently started test.
STEPCURRENT_CACHE_DIR = "cache/stepcurrent"


def pytest_addoption(parser: Parser) -> None:
    """Register the command-line flags and ini options used by this file.

    Adds the stepcurrent flags (``--scs``, ``--sc``, ``--rs``), ``--use-main-module``,
    the ``*_reruns`` JUnit XML options that mirror pytest's own ``--junit-xml``
    family, and the sharding options from ``pytest_shard_custom``.
    """
    group = parser.getgroup("general")
    # Each stepcurrent flag takes a key naming the cache entry to resume from:
    # --scs resumes after the cached test, --sc resumes at it, --rs runs only it.
    group.addoption("--scs", action="store", default=None, dest="stepcurrent_skip")
    group.addoption("--sc", action="store", default=None, dest="stepcurrent")
    group.addoption("--rs", action="store", default=None, dest="run_single")

    parser.addoption("--use-main-module", action="store_true")
    group = parser.getgroup("terminal reporting")
    group.addoption(
        "--junit-xml-reruns",
        action="store",
        dest="xmlpath_reruns",
        metavar="path",
        type=functools.partial(filename_arg, optname="--junit-xml-reruns"),
        default=None,
        help="create junit-xml style report file at given path.",
    )
    group.addoption(
        "--junit-prefix-reruns",
        action="store",
        metavar="str",
        default=None,
        help="prepend prefix to classnames in junit-xml output",
    )
    parser.addini(
        "junit_suite_name_reruns", "Test suite name for JUnit report", default="pytest"
    )
    parser.addini(
        "junit_logging_reruns",
        "Write captured log messages to JUnit report: "
        "one of no|log|system-out|system-err|out-err|all",
        default="no",
    )
    parser.addini(
        "junit_log_passing_tests_reruns",
        "Capture log information for passing tests to JUnit report: ",
        type="bool",
        default=True,
    )
    parser.addini(
        "junit_duration_report_reruns",
        "Duration time to report: one of total|call",
        default="total",
    )
    parser.addini(
        "junit_family_reruns",
        "Emit XML for schema: one of legacy|xunit1|xunit2",
        default="xunit2",
    )
    shard_addoptions(parser)


def pytest_configure(config: Config) -> None:
    """Register this file's optional plugins according to the parsed options.

    * ``--junit-xml-reruns``: create a LogXMLReruns reporter, stash it under
      ``xml_key`` and register it (skipped on xdist worker nodes).
    * ``--scs`` / ``--rs``: copied onto ``stepcurrent`` so that StepcurrentPlugin is
      registered with the requested cache key.
    * ``num_shards`` (added by ``shard_addoptions``): register PytestShardPlugin.
    """
    xmlpath = config.option.xmlpath_reruns
    # Prevent opening xmllog on worker nodes (xdist).
    if xmlpath and not hasattr(config, "workerinput"):
        junit_family = config.getini("junit_family_reruns")
        # Prefer the rerun-specific --junit-prefix-reruns (previously defined but
        # never read) and fall back to the stock --junit-prefix so existing
        # invocations behave the same. getattr guards against the builtin junitxml
        # plugin being disabled, in which case ``junitprefix`` is not defined.
        junit_prefix = config.option.junit_prefix_reruns or getattr(
            config.option, "junitprefix", None
        )
        config.stash[xml_key] = LogXMLReruns(
            xmlpath,
            junit_prefix,
            config.getini("junit_suite_name_reruns"),
            config.getini("junit_logging_reruns"),
            config.getini("junit_duration_report_reruns"),
            junit_family,
            config.getini("junit_log_passing_tests_reruns"),
        )
        config.pluginmanager.register(config.stash[xml_key])
    # --scs and --rs are variants of --sc: they share its cache key. When several
    # are given, --rs wins because it is applied last.
    if config.getoption("stepcurrent_skip"):
        config.option.stepcurrent = config.getoption("stepcurrent_skip")
    if config.getoption("run_single"):
        config.option.stepcurrent = config.getoption("run_single")
    if config.getoption("stepcurrent"):
        config.pluginmanager.register(StepcurrentPlugin(config), "stepcurrentplugin")
    if config.getoption("num_shards"):
        config.pluginmanager.register(PytestShardPlugin(config), "pytestshardplugin")


def pytest_unconfigure(config: Config) -> None:
    """Unregister and drop the LogXMLReruns reporter created in pytest_configure, if any."""
    xml = config.stash.get(xml_key, None)
    if xml:
        del config.stash[xml_key]
        config.pluginmanager.unregister(xml)


class _NodeReporterReruns(_NodeReporter):
    """Per-test XML node reporter used by LogXMLReruns.

    Writes captured output without a header and escapes skip reasons.
    """

    def _prepare_content(self, content: str, header: str) -> str:
        # Return the captured content unchanged; ``header`` is intentionally ignored
        # (presumably to drop the decoration the base class adds — verify upstream).
        return content

    def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
        """Append ``content`` as the XML-escaped text of a new ``<jheader>`` element.

        Empty content produces no element.
        """
        if content == "":
            return
        tag = ET.Element(jheader)
        tag.text = bin_xml_escape(content)
        self.append(tag)

    def append_skipped(self, report: TestReport) -> None:
        """Append a ``<skipped>`` element for ``report``.

        xfail reports are delegated to the base class. For plain skips, the skip
        reason is XML-escaped in both the ``message`` attribute and the element text.
        """
        # Referenced from the below
        # https://github.com/pytest-dev/pytest/blob/2178ee86d7c1ee93748cfb46540a6e40b4761f2d/src/_pytest/junitxml.py#L236C6-L236C6
        # Modified to escape characters not supported by xml in the skip reason. Everything else should be the same.
        if hasattr(report, "wasxfail"):
            # Super here instead of the actual code so we can reduce possible divergence
            super().append_skipped(report)
        else:
            assert isinstance(report.longrepr, tuple)
            filename, lineno, skipreason = report.longrepr
            if skipreason.startswith("Skipped: "):
                skipreason = skipreason[9:]
            details = f"{filename}:{lineno}: {skipreason}"

            skipped = ET.Element(
                "skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
            )
            skipped.text = bin_xml_escape(details)
            self.append(skipped)
            self.write_captured_output(report)


class LogXMLReruns(LogXML):
    """JUnit XML writer that also records ``rerun`` outcomes for each test case."""

    def append_rerun(self, reporter: _NodeReporter, report: TestReport) -> None:
        """Record one rerun attempt of ``report`` on ``reporter``.

        xfail-marked tests get a ``<skipped>`` element. Every other report gets a
        ``<rerun>`` element whose message is the crash message (or the whole
        longrepr when no ``reprcrash`` is available) and whose data is
        ``str(report.longrepr)``.
        """
        if hasattr(report, "wasxfail"):
            reporter._add_simple("skipped", "xfail-marked test passes unexpectedly")
        else:
            assert report.longrepr is not None
            # Annotation is quoted: ReprFileLocation is only imported under TYPE_CHECKING.
            reprcrash: Optional["ReprFileLocation"] = getattr(
                report.longrepr, "reprcrash", None
            )
            if reprcrash is not None:
                message = reprcrash.message
            else:
                message = str(report.longrepr)
            message = bin_xml_escape(message)
            reporter._add_simple("rerun", message, str(report.longrepr))

    def pytest_runtest_logreport(self, report: TestReport) -> None:
        """Run the base reporting, then add rerun entries and node-qualified skip reasons."""
        super().pytest_runtest_logreport(report)
        if report.outcome == "rerun":
            reporter = self._opentestcase(report)
            self.append_rerun(reporter, report)
        elif report.outcome == "skipped":
            # Rewrite the skip reason to include the node id. This runs after the
            # base class has handled the report, so it presumably affects later
            # consumers of ``longrepr`` (e.g. terminal summaries) — confirm.
            if isinstance(report.longrepr, tuple):
                fspath, lineno, _ = report.longrepr
                reason = f"{report.nodeid}: {_get_raw_skip_reason(report)}"
                report.longrepr = (fspath, lineno, reason)

    def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporterReruns:
        """Return the cached reporter for ``report``, creating a ``_NodeReporterReruns`` on first use.

        Reporters are keyed by ``(nodeid, workernode)``.
        """
        nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
        # Local hack to handle xdist report order.
        workernode = getattr(report, "node", None)

        key = nodeid, workernode

        if key in self.node_reporters:
            # TODO: breaks for --dist=each
            return self.node_reporters[key]

        reporter = _NodeReporterReruns(nodeid, self)

        self.node_reporters[key] = reporter
        self.node_reporters_ordered.append(reporter)

        return reporter


# imitating summary_failures in pytest's terminal.py
# both hookwrapper and tryfirst to make sure this runs before pytest's
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    """Print a "RERUNS" section with tracebacks for reruns before the normal summary.

    Does nothing when ``--tb=no``. With ``--tb=line``, prints one crash line per
    rerun; otherwise prints a headline plus the full report for each rerun.
    """
    # prints stack traces for reruns
    if terminalreporter.config.option.tbstyle != "no":
        reports = terminalreporter.getreports("rerun")
        if reports:
            terminalreporter.write_sep("=", "RERUNS")
            if terminalreporter.config.option.tbstyle == "line":
                for rep in reports:
                    line = terminalreporter._getcrashline(rep)
                    terminalreporter.write_line(line)
            else:
                for rep in reports:
                    msg = terminalreporter._getfailureheadline(rep)
                    terminalreporter.write_sep("_", msg, red=True, bold=True)
                    terminalreporter._outrep_summary(rep)
                    terminalreporter._handle_teardown_sections(rep.nodeid)
    yield


@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makemodule(module_path, path, parent) -> Optional[Module]:
    """Under ``--use-main-module``, collect tests from ``sys.modules["__main__"]``.

    The file is not re-imported. Otherwise returns None, letting pytest's default
    module collection run.
    """
    if parent.config.getoption("--use-main-module"):
        mod = Module.from_parent(parent, path=module_path)
        # Replace the module-object getter so the already-running __main__ is used.
        mod._getobj = MethodType(lambda x: sys.modules["__main__"], mod)
        return mod
    return None
@pytest.hookimpl(hookwrapper=True)
def pytest_report_teststatus(report, config):
    """Append the report's duration to the verbose status word.

    For example, ``PASSED`` becomes ``PASSED [0.0123s]``. Only TestReports with a
    non-empty verbose word are changed.
    """
    # Unfortunately I don't think this duration includes setup or teardown.
    pluggy_result = yield
    if not isinstance(report, pytest.TestReport):
        return
    outcome, letter, verbose = pluggy_result.get_result()
    if verbose:
        pluggy_result.force_result(
            (outcome, letter, f"{verbose} [{report.duration:.4f}s]")
        )


@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(items: List[Any]) -> None:
    """
    This hook is used when rerunning disabled tests to get rid of all skipped tests
    instead of running and skipping them N times. This avoids flooding the console
    and XML outputs with junk. So we want this to run last when collecting tests.

    Active only when PYTORCH_TEST_RERUN_DISABLED_TESTS == "1" and DISABLED_TESTS_FILE
    points to an existing JSON file. Entries in that file are matched against
    ``"<test_name> (<module>.<TestClass>)"``. Items whose (parent name, item name)
    pair is not listed are dropped.
    """
    rerun_disabled_tests = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
    if not rerun_disabled_tests:
        return

    disabled_regex = re.compile(r"(?P<test_name>.+)\s+\([^\.]+\.(?P<test_class>.+)\)")
    disabled_tests = defaultdict(set)

    # This environment has already been set by run_test before it calls pytest
    disabled_tests_file = os.getenv("DISABLED_TESTS_FILE", "")
    if not disabled_tests_file or not os.path.exists(disabled_tests_file):
        return

    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(disabled_tests_file, encoding="utf-8") as fp:
        for disabled_test in json.load(fp):
            m = disabled_regex.match(disabled_test)
            if m:
                test_name = m["test_name"]
                test_class = m["test_class"]
                disabled_tests[test_class].add(test_name)

    # When rerunning disabled test, ignore all test cases that are not disabled
    filtered_items = []

    for item in items:
        test_name = item.name
        test_class = item.parent.name

        if (
            test_class not in disabled_tests
            or test_name not in disabled_tests[test_class]
        ):
            continue

        # Keep a shallow copy with a re-initialized fixture request
        # (NOTE(review): rationale for copying is not visible here — confirm).
        cpy = copy.copy(item)
        cpy._initrequest()

        filtered_items.append(cpy)

    items.clear()
    # NB: Need to edit items directly here to have the list reflected back to pytest
    items.extend(filtered_items)


class StepcurrentPlugin:
    """Resume a test run from the test that was running when a previous run stopped.

    Modified from _pytest/stepwise.py in order to save the currently running test
    instead of the last failed test. The node id of each test is written to the
    cache under ``STEPCURRENT_CACHE_DIR/<key>`` just before it runs.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.report_status = ""
        assert config.cache is not None
        self.cache: pytest.Cache = config.cache
        self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
        # Node id recorded by the previous run (None if nothing was cached).
        self.lastrun: Optional[str] = self.cache.get(self.directory, None)
        # Value restored to the cache in pytest_sessionfinish on a clean exit.
        self.initial_val = self.lastrun
        # The options hold the cache key string (or None); only truthiness matters here.
        self.skip: bool = bool(config.getoption("stepcurrent_skip"))
        self.run_single: bool = bool(config.getoption("run_single"))

    def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
        """Deselect items that come before the cached test.

        The cached test itself is also deselected under --scs. Under --rs, every
        item after it is deselected too.
        """
        if not self.lastrun:
            self.report_status = "Cannot find last run test, not skipping"
            return

        # check all item nodes until we find a match on last run
        failed_index = None
        for index, item in enumerate(items):
            if item.nodeid == self.lastrun:
                failed_index = index
                if self.skip:
                    failed_index += 1
                break

        # If the previously failed test was not found among the test items,
        # do not skip any tests.
        if failed_index is None:
            self.report_status = "previously run test not found, not skipping."
        else:
            self.report_status = f"skipping {failed_index} already run items."
            deselected = items[:failed_index]
            del items[:failed_index]
            # ``items`` can be empty here when skipping past the final test
            # (--scs combined with --rs); don't index into it then.
            if self.run_single and items:
                self.report_status += f" Running only {items[0].nodeid}"
                deselected += items[1:]
                del items[1:]
            config.hook.pytest_deselected(items=deselected)

    def pytest_report_collectionfinish(self) -> Optional[str]:
        """Return a one-line stepcurrent status for the collection header, or None."""
        if self.config.getoption("verbose") >= 0 and self.report_status:
            return f"stepcurrent: {self.report_status}"
        return None

    def pytest_runtest_protocol(self, item, nextitem) -> None:
        """Persist ``item``'s node id before it runs so a crash can resume from it."""
        self.lastrun = item.nodeid
        self.cache.set(self.directory, self.lastrun)
        # Returning None lets the remaining runtest_protocol implementations run the test.

    def pytest_sessionfinish(self, session, exitstatus):
        """On a clean exit (and not in --rs mode), restore the cache entry to its value at startup."""
        if exitstatus == 0 and not self.run_single:
            self.cache.set(self.directory, self.initial_val)