xref: /aosp_15_r20/external/pytorch/tools/testing/target_determination/heuristics/interface.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

from abc import abstractmethod
from copy import copy
from typing import Any, Iterable, Iterator

from tools.testing.test_run import TestRun


class TestPrioritizations:
    """
    Describes how relevant a heuristic considers each test to be.

    The test runs tracked here are disjoint, meaning a test can only appear in one
    of them, and together they cover exactly the tests declared at initialization
    time.

    A test's score stays at 0 if the heuristic has no opinion about it.

    Important: Lists of tests must always be returned in a deterministic order,
               otherwise it breaks the test sharding logic
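
    Example (illustrative sketch; the test file names below are hypothetical):

        >>> tp = TestPrioritizations(["test_a", "test_b"], {})
        >>> tp.set_test_score(TestRun("test_a"), 1.0)
        >>> ranked = tp.get_all_tests()  # test_a ranks ahead of test_b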
    """

    _original_tests: frozenset[str]
    _test_scores: dict[TestRun, float]

    def __init__(
        self,
        tests_being_ranked: Iterable[str],  # The tests that are being prioritized.
        scores: dict[TestRun, float],
    ) -> None:
        self._original_tests = frozenset(tests_being_ranked)
        self._test_scores = {TestRun(test): 0.0 for test in self._original_tests}

        for test, score in scores.items():
            self.set_test_score(test, score)

        self.validate()

    def validate(self) -> None:
        # Union all TestRuns that contain include/exclude pairs
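        # e.g. (hypothetical names) a run of `test_a::ClassA` and a run of `test_a`
        # excluding ClassA are disjoint and union back into the full `test_a` file.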
        all_tests = self._test_scores.keys()
        files = {}
        for test in all_tests:
            if test.test_file not in files:
                files[test.test_file] = copy(test)
            else:
                assert (
                    files[test.test_file] & test
                ).is_empty(), (
                    f"Test run `{test}` overlaps with `{files[test.test_file]}`"
                )
                files[test.test_file] |= test

        for test in files.values():
            assert (
                test.is_full_file()
            ), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that"

        # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in
        assert self._original_tests == set(
            files.keys()
        ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in"

    def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]:
        # Sort by descending score, then alphabetically by test name
        for test, score in sorted(
            self._test_scores.items(), key=lambda x: (-x[1], str(x[0]))
        ):
            yield score, test

    def set_test_score(self, test_run: TestRun, new_score: float) -> None:
        if test_run.test_file not in self._original_tests:
            return  # We don't need this test

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run and tr != test_run
        ]

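        # e.g. (hypothetical names) if test_run is `test_a::ClassA` and an existing
        # entry covers all of `test_a`, that entry is split: `test_a::ClassA` gets
        # new_score and the rest of `test_a` keeps its previous score.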
        # Every test covered by test_run now gets new_score, via the test_run entry itself
        self._test_scores[test_run] = new_score
        # Tests not covered by test_run keep their original scores
        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score
        self.validate()

    def add_test_score(self, test_run: TestRun, score_to_add: float) -> None:
        if test_run.test_file not in self._original_tests:
            return

        relevant_test_runs: list[TestRun] = [
            tr for tr in self._test_scores.keys() if tr & test_run
        ]

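        # Each overlapping entry is split: the portion intersecting test_run gets
        # old_score + score_to_add, while the remainder keeps old_score.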
        for relevant_test_run in relevant_test_runs:
            old_score = self._test_scores[relevant_test_run]
            del self._test_scores[relevant_test_run]

            intersection = relevant_test_run & test_run
            if not intersection.is_empty():
                self._test_scores[intersection] = old_score + score_to_add

            not_to_be_updated = relevant_test_run - test_run
            if not not_to_be_updated.is_empty():
                self._test_scores[not_to_be_updated] = old_score

        self.validate()

    def get_all_tests(self) -> list[TestRun]:
        """Returns all tests in the TestPrioritizations"""
        return [x[1] for x in self._traverse_scores()]

    def get_top_per_tests(self, n: int) -> tuple[list[TestRun], list[TestRun]]:
        """Divides the list of tests into two based on the top n% of scores. The
        first list is the top, and the second is the rest."""
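        # e.g. with 10 tests and n=25, index = 25 * 10 // 100 + 1 = 3, so the 3
        # highest-scoring tests form the first list and the remaining 7 the second.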
        tests = [x[1] for x in self._traverse_scores()]
        index = n * len(tests) // 100 + 1
        return tests[:index], tests[index:]

    def get_info_str(self, verbose: bool = True) -> str:
        info = ""

        for score, test in self._traverse_scores():
            if not verbose and score == 0:
                continue
            info += f"  {test} ({score})\n"

        return info.rstrip()

    def print_info(self) -> None:
        print(self.get_info_str())

    def get_priority_info_for_test(self, test_run: TestRun) -> dict[str, Any]:
        """Given a failing test, returns information about its prioritization that we want to emit in our metrics."""
        for idx, (score, test) in enumerate(self._traverse_scores()):
            #  Different heuristics may result in a given test file being split
            #  into different test runs, so look for the overlapping tests to
            #  find the match
            if test & test_run:
                return {"position": idx, "score": score}
        raise AssertionError(f"Test run {test_run} not found")

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        return {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
            **self.get_priority_info_for_test(test),
            "max_score": max(score for score, _ in self._traverse_scores()),
            "min_score": min(score for score, _ in self._traverse_scores()),
            "all_scores": {
                str(test): score for test, score in self._test_scores.items()
            },
        }

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this TestPrioritizations object.
        """
        json_dict = {
            "_test_scores": [
                (test.to_json(), score)
                for test, score in self._test_scores.items()
                if score != 0
            ],
            "_original_tests": list(self._original_tests),
        }
        return json_dict

    @staticmethod
    def from_json(json_dict: dict[str, Any]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations object from a JSON dict.
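
        Example (illustrative round-trip sketch; the test name is hypothetical):

            >>> tp = TestPrioritizations(["test_a"], {TestRun("test_a"): 0.5})
            >>> restored = TestPrioritizations.from_json(tp.to_json())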
        """
        test_prioritizations = TestPrioritizations(
            tests_being_ranked=json_dict["_original_tests"],
            scores={
                TestRun.from_json(testrun_json): score
                for testrun_json, score in json_dict["_test_scores"]
            },
        )
        return test_prioritizations

    def amend_tests(self, tests: list[str]) -> None:
        """
        Removes tests that are not in the given list from the
        TestPrioritizations.  Adds tests that are in the list but not in the
        TestPrioritizations.
        """
        valid_scores = {
            test: score
            for test, score in self._test_scores.items()
            if test.test_file in tests
        }
        self._test_scores = valid_scores

        for test in tests:
            if test not in self._original_tests:
                self._test_scores[TestRun(test)] = 0
        self._original_tests = frozenset(tests)

        self.validate()


class AggregatedHeuristics:
    """
    Aggregates the results across all heuristics.

    It saves the individual results from each heuristic and exposes an aggregated view.
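
    Example (illustrative sketch; heuristic and test names are hypothetical):

        >>> agg = AggregatedHeuristics(["test_a", "test_b"])
        >>> # agg.add_heuristic_results(...) would be called once per heuristic
        >>> aggregated = agg.get_aggregated_priorities()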
    """

    _heuristic_results: dict[
        HeuristicInterface, TestPrioritizations
    ]  # Keyed by the heuristic itself. Dicts preserve insertion order, which is important for sharding

    _all_tests: frozenset[str]

    def __init__(self, all_tests: list[str]) -> None:
        self._all_tests = frozenset(all_tests)
        self._heuristic_results = {}
        self.validate()

    def validate(self) -> None:
        for heuristic, heuristic_results in self._heuristic_results.items():
            heuristic_results.validate()
            assert (
                heuristic_results._original_tests == self._all_tests
            ), f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics"

    def add_heuristic_results(
        self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations
    ) -> None:
        if heuristic in self._heuristic_results:
            raise ValueError(f"We already have results for {heuristic.name}")

        self._heuristic_results[heuristic] = heuristic_results
        self.validate()

    def get_aggregated_priorities(
        self, include_trial: bool = False
    ) -> TestPrioritizations:
        """
        Returns the aggregated priorities across all heuristics.
        """
        valid_heuristics = {
            heuristic: heuristic_results
            for heuristic, heuristic_results in self._heuristic_results.items()
            if not heuristic.trial_mode or include_trial
        }

        new_tp = TestPrioritizations(self._all_tests, {})

        for heuristic_results in valid_heuristics.values():
            for score, testrun in heuristic_results._traverse_scores():
                new_tp.add_test_score(testrun, score)
        new_tp.validate()
        return new_tp

    def get_test_stats(self, test: TestRun) -> dict[str, Any]:
        """
        Returns the aggregated statistics for a given test.
        """
        stats: dict[str, Any] = {
            "test_name": test.test_file,
            "test_filters": test.get_pytest_filter(),
        }

        # Get metrics about the heuristics used
        heuristics = []

        for heuristic, heuristic_results in self._heuristic_results.items():
            metrics = heuristic_results.get_priority_info_for_test(test)
            metrics["heuristic_name"] = heuristic.name
            metrics["trial_mode"] = heuristic.trial_mode
            heuristics.append(metrics)

        stats["heuristics"] = heuristics

        stats[
            "aggregated"
        ] = self.get_aggregated_priorities().get_priority_info_for_test(test)

        stats["aggregated_trial"] = self.get_aggregated_priorities(
            include_trial=True
        ).get_priority_info_for_test(test)

        return stats

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON dict that describes this AggregatedHeuristics object.
        """
        json_dict: dict[str, Any] = {}
        for heuristic, heuristic_results in self._heuristic_results.items():
            json_dict[heuristic.name] = heuristic_results.to_json()

        return json_dict


class HeuristicInterface:
    """
    Interface for all heuristics.
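
    Example (illustrative sketch of a minimal subclass; names are hypothetical):

        >>> class NoOpinionHeuristic(HeuristicInterface):
        ...     def get_prediction_confidence(self, tests):
        ...         return TestPrioritizations(tests, {})
        >>> tp = NoOpinionHeuristic().get_prediction_confidence(["test_a"])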
    """

    description: str

    # When trial mode is set to True, this heuristic's predictions will not be used
    # to reorder tests. Its results will, however, still be emitted in the metrics.
    trial_mode: bool

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        self.trial_mode = kwargs.get("trial_mode", False)  # type: ignore[assignment]

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name

    @abstractmethod
    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """
        Returns a TestPrioritizations in which each test's score ranges from -1 to 1,
        where negative means skip, positive means run, 0 means no idea, and the
        magnitude reflects how confident the heuristic is. Used by
        AggregatedHeuristics.
        """