from __future__ import annotations

from copy import copy
from functools import total_ordering
from typing import Any, Iterable


class TestRun:
    """
    TestRun defines the set of tests that should be run together in a single pytest invocation.
    It'll either be a whole test file or a subset of a test file.

    This class assumes that we won't always know the full set of TestClasses in a test file.
    So it's designed to include or exclude explicitly requested TestClasses, while accepting
    that there will be an ambiguous set of "unknown" test classes that are not explicitly called out.
    Those manifest as tests that haven't been explicitly excluded.

    At most one of the included/excluded sets is non-empty (enforced by __init__):
      * neither set   -> the whole file is run
      * included set  -> only those tests are run
      * excluded set  -> everything in the file except those tests is run
    An empty ``test_file`` denotes the empty run, i.e. nothing to run.
    """

    test_file: str
    _excluded: frozenset[str]  # Tests that should be excluded from this test run
    _included: frozenset[str]  # If non-empty, only these tests should be run in this test run

    def __init__(
        self,
        name: str,
        excluded: Iterable[str] | None = None,
        included: Iterable[str] | None = None,
    ) -> None:
        """
        Args:
            name: A test file, or ``"<file>::<TestClass>"`` to run a single class of that file.
            excluded: Tests to leave out of the run.
            included: Tests to restrict the run to.

        Raises:
            ValueError: if both ``included`` and ``excluded`` are non-empty.
            AssertionError: if ``name`` contains ``::`` and ``included``/``excluded`` are also given.
        """
        if excluded and included:
            raise ValueError("Can't specify both included and excluded")

        ins = set(included or [])
        exs = set(excluded or [])

        if "::" in name:
            assert (
                not included and not excluded
            ), "Can't specify included or excluded tests when specifying a test class in the file name"
            self.test_file, test_class = name.split("::")
            ins.add(test_class)
        else:
            self.test_file = name

        self._excluded = frozenset(exs)
        self._included = frozenset(ins)

    @staticmethod
    def empty() -> TestRun:
        """Return the empty run (no test file, nothing to run)."""
        return TestRun("")

    def is_empty(self) -> bool:
        # Lack of a test_file means that this is an empty run,
        # which means there is nothing to run. It's the zero.
        return not self.test_file

    def is_full_file(self) -> bool:
        """True when neither inclusions nor exclusions are set, i.e. the whole file runs."""
        return not self._included and not self._excluded

    def included(self) -> frozenset[str]:
        return self._included

    def excluded(self) -> frozenset[str]:
        return self._excluded

    def get_pytest_filter(self) -> str:
        """Return the pytest ``-k`` expression selecting this run's tests ("" for the whole file)."""
        if self._included:
            return " or ".join(sorted(self._included))
        elif self._excluded:
            return f"not ({' or '.join(sorted(self._excluded))})"
        else:
            return ""

    def contains(self, test: TestRun) -> bool:
        """Return True if every test that ``test`` would run is also run by ``self``."""
        if self.test_file != test.test_file:
            return False

        if self.is_full_file():
            return True  # self contains all tests

        if test.is_full_file():
            return False  # test contains all tests, but self doesn't

        if test._excluded:
            # test runs an open-ended set: everything except test._excluded.
            # An include-only self can never cover that. An exclude-style self
            # covers it only if everything self skips is also skipped by test,
            # i.e. self excludes a subset of what test excludes.
            if self._included:
                return False
            return self._excluded.issubset(test._excluded)

        # From here on, test only runs what's in test._included.
        if self._included:
            # Does self include everything test includes?
            return test._included.issubset(self._included)

        # Getting to here means that test includes and self excludes.
        # Does self exclude anything test includes? If not, we're good
        return not self._excluded.intersection(test._included)

    def __copy__(self) -> TestRun:
        return TestRun(self.test_file, excluded=self._excluded, included=self._included)

    def __bool__(self) -> bool:
        return not self.is_empty()

    def __repr__(self) -> str:
        # NOTE(review): the label reads "RunTest" although the class is TestRun;
        # kept byte-for-byte in case anything matches on this text.
        r: str = f"RunTest({self.test_file}"
        r += f", included: {self._included}" if self._included else ""
        r += f", excluded: {self._excluded}" if self._excluded else ""
        r += ")"
        return r

    def __str__(self) -> str:
        if self.is_empty():
            return "Empty"

        pytest_filter = self.get_pytest_filter()
        if pytest_filter:
            return self.test_file + ", " + pytest_filter
        return self.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TestRun):
            return False

        ret = self.test_file == other.test_file
        ret = ret and self._included == other._included
        ret = ret and self._excluded == other._excluded
        return ret

    def __hash__(self) -> int:
        return hash((self.test_file, self._included, self._excluded))

    def __or__(self, other: TestRun) -> TestRun:
        """
        To OR/Union test runs means to run all the tests that either of the two runs specify.
        """

        # Is any file empty?
        if self.is_empty():
            return copy(other)
        if other.is_empty():
            return copy(self)

        # If not, ensure we have the same file
        assert (
            self.test_file == other.test_file
        ), f"Can't union {other} with {self} because they're not the same test file"

        # 4 possible cases:

        # 1. Either file is the full file, so union is everything
        if self.is_full_file() or other.is_full_file():
            # The union is the whole file
            return TestRun(self.test_file)

        # 2. Both files only run what's in _included, so union is the union of the two sets
        if self._included and other._included:
            return TestRun(
                self.test_file, included=self._included.union(other._included)
            )

        # 3. Both files only exclude what's in _excluded, so union is the intersection of the two sets
        if self._excluded and other._excluded:
            return TestRun(
                self.test_file, excluded=self._excluded.intersection(other._excluded)
            )

        # 4. One file includes and the other excludes, so we then continue excluding the _excluded set minus
        # whatever is in the _included set
        included = self._included | other._included
        excluded = self._excluded | other._excluded
        return TestRun(self.test_file, excluded=excluded - included)

    def __sub__(self, other: TestRun) -> TestRun:
        """
        To subtract test runs means to run all the tests in the first run except for what the second run specifies.
        """

        # Is any file empty?
        if self.is_empty():
            return TestRun.empty()
        if other.is_empty():
            return copy(self)

        # Are you trying to subtract tests that don't even exist in this test run?
        if self.test_file != other.test_file:
            return copy(self)

        # You're subtracting everything?
        if other.is_full_file():
            return TestRun.empty()

        def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun:
            # An include-style run with nothing left to include is the empty run.
            if inclusions:
                return TestRun(self.test_file, included=inclusions)
            return TestRun.empty()

        if other._included:
            if self._included:
                return return_inclusions_or_empty(self._included - other._included)
            else:
                # self runs "all but X"; removing other's inclusions widens X.
                return TestRun(
                    self.test_file, excluded=self._excluded | other._included
                )
        else:
            if self._included:
                # Only the tests other skips survive the subtraction.
                return return_inclusions_or_empty(self._included & other._excluded)
            else:
                # (all - E_self) - (all - E_other) == E_other - E_self
                return return_inclusions_or_empty(other._excluded - self._excluded)

    def __and__(self, other: TestRun) -> TestRun:
        """Intersection: the tests run by both ``self`` and ``other``."""
        if self.test_file != other.test_file:
            return TestRun.empty()

        return (self | other) - (self - other) - (other - self)

    def to_json(self) -> dict[str, Any]:
        """Serialise to a JSON-compatible dict; test lists are sorted for reproducible output."""
        r: dict[str, Any] = {
            "test_file": self.test_file,
        }
        if self._included:
            r["included"] = sorted(self._included)
        if self._excluded:
            r["excluded"] = sorted(self._excluded)
        return r

    @staticmethod
    def from_json(json: dict[str, Any]) -> TestRun:
        """Inverse of ``to_json``."""
        return TestRun(
            json["test_file"],
            included=json.get("included", []),
            excluded=json.get("excluded", []),
        )


@total_ordering
class ShardedTest:
    """A TestRun assigned to shard ``shard`` of ``num_shards``, with an optional duration."""

    test: TestRun
    shard: int
    num_shards: int
    time: float | None  # In seconds

    def __init__(
        self,
        test: TestRun | str,
        shard: int,
        num_shards: int,
        time: float | None = None,
    ) -> None:
        # A bare string is treated as a TestRun name (file or "file::Class").
        if isinstance(test, str):
            test = TestRun(test)
        self.test = test
        self.shard = shard
        self.num_shards = num_shards
        self.time = time

    @property
    def name(self) -> str:
        return self.test.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            return False
        return (
            self.test == other.test
            and self.shard == other.shard
            and self.num_shards == other.num_shards
            and self.time == other.time
        )

    def __repr__(self) -> str:
        ret = f"{self.test} {self.shard}/{self.num_shards}"
        if self.time:
            ret += f" ({self.time}s)"

        return ret

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            # Let Python try the reflected operation / raise TypeError.
            return NotImplemented

        # This is how the list was implicitly sorted when it was a NamedTuple
        if self.name != other.name:
            return self.name < other.name
        if self.shard != other.shard:
            return self.shard < other.shard
        if self.num_shards != other.num_shards:
            return self.num_shards < other.num_shards

        # None is the smallest value. Two Nones are equal, so neither is
        # less than the other (keeps `<` irreflexive for total_ordering/sort).
        if self.time is None:
            return other.time is not None
        if other.time is None:
            return False
        return self.time < other.time

    def __str__(self) -> str:
        return f"{self.test} {self.shard}/{self.num_shards}"

    def get_time(self, default: float = 0) -> float:
        """Return the recorded time, or ``default`` when none was recorded."""
        return self.time if self.time is not None else default

    def get_pytest_args(self) -> list[str]:
        """Return the extra pytest arguments (a ``-k`` filter) selecting this shard's tests."""
        pytest_filter = self.test.get_pytest_filter()
        if pytest_filter:
            return ["-k", pytest_filter]
        return []