1from __future__ import annotations 2 3from typing import Any 4from warnings import warn 5 6from tools.testing.target_determination.heuristics.interface import ( 7 HeuristicInterface, 8 TestPrioritizations, 9) 10from tools.testing.target_determination.heuristics.utils import query_changed_files 11from tools.testing.test_run import TestRun 12 13 14class PublicBindings(HeuristicInterface): 15 # Literally just a heuristic for test_public_bindings. Pretty much anything 16 # that changes the public API can affect this testp 17 test_public_bindings = "test_public_bindings" 18 additional_files = ["test/allowlist_for_publicAPI.json"] 19 20 def __init__(self, **kwargs: dict[str, Any]) -> None: 21 super().__init__(**kwargs) 22 23 def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: 24 test_ratings = {} 25 try: 26 changed_files = query_changed_files() 27 except Exception as e: 28 warn(f"Can't query changed test files due to {e}") 29 changed_files = [] 30 31 if any( 32 file.startswith("torch/") or file in self.additional_files 33 for file in changed_files 34 ): 35 test_ratings[TestRun(self.test_public_bindings)] = 1.0 36 return TestPrioritizations(tests, test_ratings) 37