1# For testing specific heuristics 2from __future__ import annotations 3 4import io 5import json 6import sys 7import unittest 8from pathlib import Path 9from typing import Any 10from unittest import mock 11 12 13REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent 14sys.path.append(str(REPO_ROOT)) 15 16from tools.test.heuristics.test_interface import TestTD 17from tools.testing.target_determination.determinator import TestPrioritizations 18from tools.testing.target_determination.heuristics.filepath import ( 19 file_matches_keyword, 20 get_keywords, 21) 22from tools.testing.target_determination.heuristics.historical_class_failure_correlation import ( 23 HistoricalClassFailurCorrelation, 24) 25from tools.testing.target_determination.heuristics.previously_failed_in_pr import ( 26 get_previous_failures, 27) 28from tools.testing.test_run import TestRun 29 30 31sys.path.remove(str(REPO_ROOT)) 32 33HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation." 34 35 36def mocked_file(contents: dict[Any, Any]) -> io.IOBase: 37 file_object = io.StringIO() 38 json.dump(contents, file_object) 39 file_object.seek(0) 40 return file_object 41 42 43def gen_historical_class_failures() -> dict[str, dict[str, float]]: 44 return { 45 "file1": { 46 "test1::classA": 0.5, 47 "test2::classA": 0.2, 48 "test5::classB": 0.1, 49 }, 50 "file2": { 51 "test1::classB": 0.3, 52 "test3::classA": 0.2, 53 "test5::classA": 1.5, 54 "test7::classC": 0.1, 55 }, 56 "file3": { 57 "test1::classC": 0.4, 58 "test4::classA": 0.2, 59 "test7::classC": 1.5, 60 "test8::classC": 0.1, 61 }, 62 } 63 64 65ALL_TESTS = [ 66 "test1", 67 "test2", 68 "test3", 69 "test4", 70 "test5", 71 "test6", 72 "test7", 73 "test8", 74] 75 76 77class TestHistoricalClassFailureCorrelation(TestTD): 78 @mock.patch( 79 HEURISTIC_CLASS + "_get_historical_test_class_correlations", 80 return_value=gen_historical_class_failures(), 81 ) 82 @mock.patch( 83 HEURISTIC_CLASS + "query_changed_files", 84 return_value=["file1"], 85 ) 86 def test_get_prediction_confidence( 87 self, 88 historical_class_failures: dict[str, dict[str, float]], 89 changed_files: list[str], 90 ) -> None: 91 tests_to_prioritize = ALL_TESTS 92 93 heuristic = HistoricalClassFailurCorrelation() 94 test_prioritizations = heuristic.get_prediction_confidence(tests_to_prioritize) 95 96 expected = TestPrioritizations( 97 tests_to_prioritize, 98 { 99 TestRun("test1::classA"): 0.25, 100 TestRun("test2::classA"): 0.1, 101 TestRun("test5::classB"): 0.05, 102 TestRun("test1", excluded=["classA"]): 0.0, 103 TestRun("test2", excluded=["classA"]): 0.0, 104 TestRun("test3"): 0.0, 105 TestRun("test4"): 0.0, 106 TestRun("test5", excluded=["classB"]): 0.0, 107 TestRun("test6"): 0.0, 108 TestRun("test7"): 0.0, 109 TestRun("test8"): 0.0, 110 }, 111 ) 112 113 self.assert_test_scores_almost_equal( 114 test_prioritizations._test_scores, expected._test_scores 115 ) 116 117 118class TestParsePrevTests(TestTD): 119 @mock.patch("os.path.exists", return_value=False) 120 def test_cache_does_not_exist(self, mock_exists: Any) -> None: 121 expected_failing_test_files: set[str] = set() 122 123 found_tests = get_previous_failures() 124 125 self.assertSetEqual(expected_failing_test_files, found_tests) 126 127 @mock.patch("os.path.exists", return_value=True) 128 @mock.patch("builtins.open", return_value=mocked_file({"": True})) 129 def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None: 130 expected_failing_test_files: set[str] = set() 131 132 found_tests = get_previous_failures() 133 134 self.assertSetEqual(expected_failing_test_files, found_tests) 135 mock_open.assert_called() 136 137 lastfailed_with_multiple_tests_per_file = { 138 "test/test_car.py::TestCar::test_num[17]": True, 139 "test/test_car.py::TestBar::test_num[25]": True, 140 "test/test_far.py::TestFar::test_fun_copy[17]": True, 141 "test/test_bar.py::TestBar::test_fun_copy[25]": True, 142 } 143 144 @mock.patch("os.path.exists", return_value=True) 145 @mock.patch( 146 "builtins.open", 147 return_value=mocked_file(lastfailed_with_multiple_tests_per_file), 148 ) 149 def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None: 150 expected_failing_test_files = {"test_car", "test_bar", "test_far"} 151 found_tests = get_previous_failures() 152 153 self.assertSetEqual(expected_failing_test_files, found_tests) 154 155 156class TestFilePath(TestTD): 157 def test_get_keywords(self) -> None: 158 self.assertEqual(get_keywords("test/test_car.py"), []) 159 self.assertEqual(get_keywords("test/nn/test_amp.py"), ["nn"]) 160 self.assertEqual(get_keywords("torch/nn/test_amp.py"), ["nn"]) 161 self.assertEqual( 162 get_keywords("torch/nn/mixed_precision/test_amp.py"), ["nn", "amp"] 163 ) 164 165 def test_match_keywords(self) -> None: 166 self.assertTrue(file_matches_keyword("test/quantization/test_car.py", "quant")) 167 self.assertTrue(file_matches_keyword("test/test_quantization.py", "quant")) 168 self.assertTrue(file_matches_keyword("test/nn/test_amp.py", "nn")) 169 self.assertTrue(file_matches_keyword("test/nn/test_amp.py", "amp")) 170 self.assertTrue(file_matches_keyword("test/test_onnx.py", "onnx")) 171 self.assertFalse(file_matches_keyword("test/test_onnx.py", "nn")) 172 173 def test_get_keywords_match(self) -> None: 174 def helper(test_file: str, changed_file: str) -> bool: 175 return any( 176 file_matches_keyword(test_file, x) for x in get_keywords(changed_file) 177 ) 178 179 self.assertTrue(helper("test/quantization/test_car.py", "quantize/t.py")) 180 self.assertFalse(helper("test/onnx/test_car.py", "nn/t.py")) 181 self.assertTrue(helper("test/nn/test_car.py", "nn/t.py")) 182 self.assertFalse(helper("test/nn/test_car.py", "test/b.py")) 183 self.assertTrue(helper("test/test_mixed_precision.py", "torch/amp/t.py")) 184 self.assertTrue(helper("test/test_amp.py", "torch/mixed_precision/t.py")) 185 self.assertTrue(helper("test/idk/other/random.py", "torch/idk/t.py")) 186 187 188if __name__ == "__main__": 189 unittest.main() 190