xref: /aosp_15_r20/external/pytorch/tools/test/heuristics/test_heuristics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# For testing specific heuristics
2from __future__ import annotations
3
4import io
5import json
6import sys
7import unittest
8from pathlib import Path
9from typing import Any
10from unittest import mock
11
12
13REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
14sys.path.append(str(REPO_ROOT))
15
16from tools.test.heuristics.test_interface import TestTD
17from tools.testing.target_determination.determinator import TestPrioritizations
18from tools.testing.target_determination.heuristics.filepath import (
19    file_matches_keyword,
20    get_keywords,
21)
22from tools.testing.target_determination.heuristics.historical_class_failure_correlation import (
23    HistoricalClassFailurCorrelation,
24)
25from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
26    get_previous_failures,
27)
28from tools.testing.test_run import TestRun
29
30
31sys.path.remove(str(REPO_ROOT))
32
# Dotted module-path prefix (despite the name, a module, not a class) that is
# concatenated with an attribute name to form ``mock.patch`` targets inside the
# historical-class-failure-correlation heuristic module.
HEURISTIC_CLASS = (
    "tools.testing.target_determination.heuristics."
    "historical_class_failure_correlation."
)
34
35
def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
    """Serialize ``contents`` as JSON into an in-memory text stream.

    The returned ``io.StringIO`` is positioned at offset 0, so it can be read
    immediately, e.g. when handed out as the return value of a patched
    ``builtins.open``.
    """
    serialized = json.dumps(contents)
    # StringIO initialised with a value starts reading from position 0.
    return io.StringIO(serialized)
41
42
43def gen_historical_class_failures() -> dict[str, dict[str, float]]:
44    return {
45        "file1": {
46            "test1::classA": 0.5,
47            "test2::classA": 0.2,
48            "test5::classB": 0.1,
49        },
50        "file2": {
51            "test1::classB": 0.3,
52            "test3::classA": 0.2,
53            "test5::classA": 1.5,
54            "test7::classC": 0.1,
55        },
56        "file3": {
57            "test1::classC": 0.4,
58            "test4::classA": 0.2,
59            "test7::classC": 1.5,
60            "test8::classC": 0.1,
61        },
62    }
63
64
# The full set of test names handed to the heuristics: "test1" .. "test8".
ALL_TESTS = [f"test{index}" for index in range(1, 9)]
75
76
class TestHistoricalClassFailureCorrelation(TestTD):
    """Tests for the ``HistoricalClassFailurCorrelation`` heuristic."""

    # Stacked ``mock.patch`` decorators hand their mocks to the test bottom-up:
    # the decorator closest to ``def`` (``query_changed_files``) supplies the
    # first mock argument, the outer one (``_get_historical_test_class_correlations``)
    # the second. The parameter names below follow that order.
    @mock.patch(
        HEURISTIC_CLASS + "_get_historical_test_class_correlations",
        return_value=gen_historical_class_failures(),
    )
    @mock.patch(
        HEURISTIC_CLASS + "query_changed_files",
        return_value=["file1"],
    )
    def test_get_prediction_confidence(
        self,
        mock_query_changed_files: mock.MagicMock,
        mock_historical_correlations: mock.MagicMock,
    ) -> None:
        """With only ``file1`` changed, only classes correlated with it score > 0.

        Every other test (or the remainder of a test once the scored class is
        excluded) is expected to score 0.0. The expected non-zero scores are
        half of the raw correlations mocked for ``file1`` (0.5 -> 0.25, etc.);
        presumably a normalization inside the heuristic — confirm there.
        """
        tests_to_prioritize = ALL_TESTS

        heuristic = HistoricalClassFailurCorrelation()
        test_prioritizations = heuristic.get_prediction_confidence(tests_to_prioritize)

        # The expected scores are derived from the mocked data, so both patched
        # helpers must actually have been consulted.
        mock_query_changed_files.assert_called()
        mock_historical_correlations.assert_called()

        expected = TestPrioritizations(
            tests_to_prioritize,
            {
                TestRun("test1::classA"): 0.25,
                TestRun("test2::classA"): 0.1,
                TestRun("test5::classB"): 0.05,
                TestRun("test1", excluded=["classA"]): 0.0,
                TestRun("test2", excluded=["classA"]): 0.0,
                TestRun("test3"): 0.0,
                TestRun("test4"): 0.0,
                TestRun("test5", excluded=["classB"]): 0.0,
                TestRun("test6"): 0.0,
                TestRun("test7"): 0.0,
                TestRun("test8"): 0.0,
            },
        )

        self.assert_test_scores_almost_equal(
            test_prioritizations._test_scores, expected._test_scores
        )
116
117
class TestParsePrevTests(TestTD):
    """Tests for ``get_previous_failures`` with the filesystem patched out."""

    @mock.patch("os.path.exists", return_value=False)
    def test_cache_does_not_exist(self, mock_exists: mock.MagicMock) -> None:
        """No previous failures are reported when the cache path is missing."""
        expected_failing_test_files: set[str] = set()

        found_tests = get_previous_failures()

        self.assertSetEqual(expected_failing_test_files, found_tests)

    # Stacked ``mock.patch`` decorators pass their mocks bottom-up: the patch of
    # ``builtins.open`` (closest to ``def``) is the FIRST mock argument and the
    # patch of ``os.path.exists`` the second. The signatures below follow that
    # order so ``mock_open`` really is the ``open`` mock.
    #
    # NOTE(review): each ``mocked_file(...)`` is evaluated once, when the class
    # body executes, and the patched ``open`` returns that same StringIO on
    # every call; re-running a test in the same process would see an
    # already-consumed file.
    @mock.patch("os.path.exists", return_value=True)
    @mock.patch("builtins.open", return_value=mocked_file({"": True}))
    def test_empty_cache(
        self, mock_open: mock.MagicMock, mock_exists: mock.MagicMock
    ) -> None:
        """A cache whose only key is empty yields no failing test files."""
        expected_failing_test_files: set[str] = set()

        found_tests = get_previous_failures()

        self.assertSetEqual(expected_failing_test_files, found_tests)
        # Both the existence check and the actual file read must have happened.
        mock_exists.assert_called()
        mock_open.assert_called()

    # Pytest-style "lastfailed" entries; several failing classes share a file.
    lastfailed_with_multiple_tests_per_file = {
        "test/test_car.py::TestCar::test_num[17]": True,
        "test/test_car.py::TestBar::test_num[25]": True,
        "test/test_far.py::TestFar::test_fun_copy[17]": True,
        "test/test_bar.py::TestBar::test_fun_copy[25]": True,
    }

    @mock.patch("os.path.exists", return_value=True)
    @mock.patch(
        "builtins.open",
        return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
    )
    def test_dedupes_failing_test_files(
        self, mock_open: mock.MagicMock, mock_exists: mock.MagicMock
    ) -> None:
        """Multiple failures in one file collapse to a single test-file entry."""
        expected_failing_test_files = {"test_car", "test_bar", "test_far"}
        found_tests = get_previous_failures()

        self.assertSetEqual(expected_failing_test_files, found_tests)
154
155
class TestFilePath(TestTD):
    """Tests for the keyword helpers of the filepath heuristic."""

    def test_get_keywords(self) -> None:
        """``get_keywords`` extracts directory-derived keywords from a path."""
        path_to_keywords = [
            ("test/test_car.py", []),
            ("test/nn/test_amp.py", ["nn"]),
            ("torch/nn/test_amp.py", ["nn"]),
            ("torch/nn/mixed_precision/test_amp.py", ["nn", "amp"]),
        ]
        for path, keywords in path_to_keywords:
            self.assertEqual(get_keywords(path), keywords)

    def test_match_keywords(self) -> None:
        """``file_matches_keyword`` matches on directories and file names."""
        positive_pairs = [
            ("test/quantization/test_car.py", "quant"),
            ("test/test_quantization.py", "quant"),
            ("test/nn/test_amp.py", "nn"),
            ("test/nn/test_amp.py", "amp"),
            ("test/test_onnx.py", "onnx"),
        ]
        for test_path, keyword in positive_pairs:
            self.assertTrue(file_matches_keyword(test_path, keyword))
        # "onnx" contains "nn" as a substring but must not count as a match.
        self.assertFalse(file_matches_keyword("test/test_onnx.py", "nn"))

    def test_get_keywords_match(self) -> None:
        """A test file is relevant when any keyword of a changed file matches it."""

        def relevant(test_path: str, changed_path: str) -> bool:
            keywords = get_keywords(changed_path)
            return any(file_matches_keyword(test_path, kw) for kw in keywords)

        # (test file, changed file, expected relevance), checked in order.
        scenarios = [
            ("test/quantization/test_car.py", "quantize/t.py", True),
            ("test/onnx/test_car.py", "nn/t.py", False),
            ("test/nn/test_car.py", "nn/t.py", True),
            ("test/nn/test_car.py", "test/b.py", False),
            ("test/test_mixed_precision.py", "torch/amp/t.py", True),
            ("test/test_amp.py", "torch/mixed_precision/t.py", True),
            ("test/idk/other/random.py", "torch/idk/t.py", True),
        ]
        for test_path, changed_path, should_match in scenarios:
            check = self.assertTrue if should_match else self.assertFalse
            check(relevant(test_path, changed_path))
186
187
if __name__ == "__main__":
    # Executing this file as a script runs the test cases in this module via unittest.
    unittest.main()
190