from __future__ import annotations

from abc import abstractmethod
from copy import copy
from typing import Any, Iterable, Iterator

from tools.testing.test_run import TestRun


class TestPrioritizations:
    """
    Describes how relevant a heuristic considers each test to be.

    Every test file being ranked is covered by one or more disjoint TestRuns
    (so any given test is covered by exactly one TestRun), and each TestRun
    carries a relevance score. The set of tests covered is fixed at
    initialization time.

    Important: Tests must always be returned in a deterministic order,
    otherwise it breaks the test sharding logic.
    """

    _original_tests: frozenset[str]
    _test_scores: dict[TestRun, float]

    def __init__(
        self,
        tests_being_ranked: Iterable[str],  # The tests that are being prioritized.
        scores: dict[TestRun, float],
    ) -> None:
        self._original_tests = frozenset(tests_being_ranked)
        self._test_scores = {TestRun(test): 0.0 for test in self._original_tests}

        for test, score in scores.items():
            self.set_test_score(test, score)

        self.validate()

    def validate(self) -> None:
        # Union all TestRuns that contain include/exclude pairs
        all_tests = self._test_scores.keys()
        files: dict[str, TestRun] = {}
        for test in all_tests:
            if test.test_file not in files:
                files[test.test_file] = copy(test)
            else:
                assert (
                    files[test.test_file] & test
                ).is_empty(), (
                    f"Test run `{test}` overlaps with `{files[test.test_file]}`"
                )
                files[test.test_file] |= test

        for test in files.values():
            assert (
                test.is_full_file()
            ), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"

        # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
        assert self._original_tests == set(
            files.keys()
        ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"

    def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
        # Sort by score (highest first), then alphabetically by test name
        for test, score in sorted(
            self._test_scores.items(), key=lambda x: (-x[1], str(x[0]))
        ):
            yield score, test

    def set_test_score(self, test_run: TestRun, new_score: float) -> None:
        if test_run.test_file not in self._original_tests:
            return  # We don't need this test

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run and tr != test_run
        ]

        # Assign the new score to the requested TestRun
        self._test_scores[test_run] = new_score
        # Any part of an overlapping TestRun that is not covered by test_run
        # keeps its original score
        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score
        self.validate()

    def add_test_score(self, test_run: TestRun, score_to_add: float) -> None:
        if test_run.test_file not in self._original_tests:
            return

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run
        ]

        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            # The overlap with test_run gets the boosted score
            intersection = relevant_test_run & test_run
            if not intersection.is_empty():
                self._test_scores[intersection] = old_score + score_to_add

            # The remainder keeps its original score
            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score

        self.validate()
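
    # Worked example (a hedged sketch, not executed anywhere): the test names
    # below are made up, and the "file::TestClass" form assumes TestRun's
    # constructor accepts it, as its pytest-filter support suggests.
    #
    #   tp = TestPrioritizations(["test_a.py"], {})
    #   tp.add_test_score(TestRun("test_a.py::TestFoo"), 1.0)
    #
    # After the call, "test_a.py::TestFoo" scores 1.0 while the remainder of
    # the file ("test_a.py" excluding TestFoo) keeps its previous score of
    # 0.0, and validate() still sees the TestRuns unioning back to the full
    # file.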

    def get_all_tests(self) -> list[TestRun]:
        """Returns all tests in the TestPrioritizations"""
        return [x[1] for x in self._traverse_scores()]

    def get_top_per_tests(self, n: int) -> tuple[list[TestRun], list[TestRun]]:
        """Divides the list of tests into two based on the top n% of scores.
        The first list is the top, and the second is the rest."""
        tests = [x[1] for x in self._traverse_scores()]
        index = n * len(tests) // 100 + 1
        return tests[:index], tests[index:]

    def get_info_str(self, verbose: bool = True) -> str:
        info = ""

        for score, test in self._traverse_scores():
            if not verbose and score == 0:
                continue
            info += f"  {test} ({score})\n"

        return info.rstrip()

    def print_info(self) -> None:
        print(self.get_info_str())

    def get_priority_info_for_test(self, test_run: TestRun) -> dict[str, Any]:
        """Given a failing test, returns information about its prioritization that we want to emit in our metrics."""
        for idx, (score, test) in enumerate(self._traverse_scores()):
            # Different heuristics may result in a given test file being split
            # into different test runs, so look for the overlapping tests to
            # find the match
            if test & test_run:
                return {"position": idx, "score": score}

        raise AssertionError(f"Test run {test_run} not found")

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        return {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
            **self.get_priority_info_for_test(test),
            "max_score": max(score for score, _ in self._traverse_scores()),
            "min_score": min(score for score, _ in self._traverse_scores()),
            "all_scores": {
                str(test): score for test, score in self._test_scores.items()
            },
        }

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this TestPrioritizations object.
        """
        json_dict = {
            "_test_scores": [
                (test.to_json(), score)
                for test, score in self._test_scores.items()
                if score != 0
            ],
            "_original_tests": list(self._original_tests),
        }
        return json_dict

    @staticmethod
    def from_json(json_dict: dict[str, Any]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations object from a JSON dict.
        """
        test_prioritizations = TestPrioritizations(
            tests_being_ranked=json_dict["_original_tests"],
            scores={
                TestRun.from_json(testrun_json): score
                for testrun_json, score in json_dict["_test_scores"]
            },
        )
        return test_prioritizations

    def amend_tests(self, tests: list[str]) -> None:
        """
        Removes tests that are not in the given list from the
        TestPrioritizations. Adds tests that are in the list but not in the
        TestPrioritizations.
        """
        valid_scores = {
            test: score
            for test, score in self._test_scores.items()
            if test.test_file in tests
        }
        self._test_scores = valid_scores

        for test in tests:
            if test not in self._original_tests:
                self._test_scores[TestRun(test)] = 0
        self._original_tests = frozenset(tests)

        self.validate()
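

# A minimal usage sketch (illustrative only; nothing in the framework calls
# it): shows how a heuristic might build up a TestPrioritizations. The helper
# name and the test file names are made up for this example.
def _example_build_prioritizations() -> tuple[list[TestRun], list[TestRun]]:
    tests = ["test_autograd.py", "test_nn.py", "test_ops.py"]
    # Seed one file with a score at construction time
    tp = TestPrioritizations(tests, {TestRun("test_nn.py"): 0.5})
    # Bump another file's score afterwards
    tp.add_test_score(TestRun("test_autograd.py"), 0.25)
    # Highest scores come first, ties broken alphabetically: test_nn.py,
    # test_autograd.py, then test_ops.py. With n=50 the first two land in
    # the "top" bucket.
    return tp.get_top_per_tests(50)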


class AggregatedHeuristics:
    """
    Aggregates the results across all heuristics.

    It saves the individual results from each heuristic and exposes an
    aggregated view.
    """

    # Keyed by the heuristic object itself (its name is used when serializing).
    # Dicts preserve insertion order, which is important for sharding.
    _heuristic_results: dict[HeuristicInterface, TestPrioritizations]

    _all_tests: frozenset[str]

    def __init__(self, all_tests: list[str]) -> None:
        self._all_tests = frozenset(all_tests)
        self._heuristic_results = {}
        self.validate()

    def validate(self) -> None:
        for heuristic, heuristic_results in self._heuristic_results.items():
            heuristic_results.validate()
            assert (
                heuristic_results._original_tests == self._all_tests
            ), f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics"

    def add_heuristic_results(
        self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations
    ) -> None:
        if heuristic in self._heuristic_results:
            raise ValueError(f"We already have results for {heuristic.name}")

        self._heuristic_results[heuristic] = heuristic_results
        self.validate()

    def get_aggregated_priorities(
        self, include_trial: bool = False
    ) -> TestPrioritizations:
        """
        Returns the aggregated priorities across all heuristics.
        """
        valid_heuristics = {
            heuristic: heuristic_results
            for heuristic, heuristic_results in self._heuristic_results.items()
            if not heuristic.trial_mode or include_trial
        }

        new_tp = TestPrioritizations(self._all_tests, {})

        for heuristic_results in valid_heuristics.values():
            for score, testrun in heuristic_results._traverse_scores():
                new_tp.add_test_score(testrun, score)

        new_tp.validate()
        return new_tp

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        """
        Returns the aggregated statistics for a given test.
        """
        stats: dict[str, Any] = {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
        }

        # Get metrics about the heuristics used
        heuristics = []

        for heuristic, heuristic_results in self._heuristic_results.items():
            metrics = heuristic_results.get_priority_info_for_test(test)
            metrics["heuristic_name"] = heuristic.name
            metrics["trial_mode"] = heuristic.trial_mode
            heuristics.append(metrics)

        stats["heuristics"] = heuristics

        stats[
            "aggregated"
        ] = self.get_aggregated_priorities().get_priority_info_for_test(test)

        stats["aggregated_trial"] = self.get_aggregated_priorities(
            include_trial=True
        ).get_priority_info_for_test(test)

        return stats

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this AggregatedHeuristics object.
        """
        json_dict: dict[str, Any] = {}
        for heuristic, heuristic_results in self._heuristic_results.items():
            json_dict[heuristic.name] = heuristic_results.to_json()

        return json_dict
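

# A minimal sketch of how aggregation combines scores (illustrative only; it
# mirrors what get_aggregated_priorities does): each heuristic's score for a
# TestRun is added onto a running total. The helper and the file names below
# are made up for this example.
def _example_score_accumulation() -> TestPrioritizations:
    combined = TestPrioritizations(["test_a.py", "test_b.py"], {})
    combined.add_test_score(TestRun("test_a.py"), 0.5)  # e.g. from heuristic 1
    combined.add_test_score(TestRun("test_a.py"), 0.25)  # e.g. from heuristic 2
    # test_a.py now scores 0.75; test_b.py stays at 0.0
    return combined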


class HeuristicInterface:
    """
    Interface for all heuristics.
    """

    description: str

    # When trial_mode is True, this heuristic's predictions are not used to
    # reorder tests. Its results are still emitted in the metrics, however.
    trial_mode: bool

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        self.trial_mode = kwargs.get("trial_mode", False)  # type: ignore[assignment]

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name

    @abstractmethod
    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations in which each test's score ranges from
        -1 to 1: negative means skip, positive means run, 0 means no idea, and
        the magnitude reflects how confident the heuristic is. Used by
        AggregatedHeuristics.
        """
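

# Hedged end-to-end sketch (illustrative only; nothing in the framework uses
# it). `_ExampleHeuristic`, the helper below, and the test file names are made
# up here to show how a concrete heuristic would plug into
# AggregatedHeuristics.
class _ExampleHeuristic(HeuristicInterface):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        # Pretend the first test file is the most relevant one
        scores: dict[TestRun, float] = {}
        if tests:
            scores[TestRun(tests[0])] = 1.0
        return TestPrioritizations(tests, scores)


def _example_end_to_end() -> TestPrioritizations:
    example_tests = ["test_a.py", "test_b.py"]
    aggregated = AggregatedHeuristics(example_tests)
    example_heuristic = _ExampleHeuristic()
    aggregated.add_heuristic_results(
        example_heuristic, example_heuristic.get_prediction_confidence(example_tests)
    )
    # Highest score first: test_a.py (1.0), then test_b.py (0.0)
    return aggregated.get_aggregated_priorities()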