xref: /aosp_15_r20/external/pytorch/tools/testing/target_determination/heuristics/public_bindings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from typing import Any
4from warnings import warn
5
6from tools.testing.target_determination.heuristics.interface import (
7    HeuristicInterface,
8    TestPrioritizations,
9)
10from tools.testing.target_determination.heuristics.utils import query_changed_files
11from tools.testing.test_run import TestRun
12
13
class PublicBindings(HeuristicInterface):
    """Heuristic that prioritizes ``test_public_bindings`` for API-affecting changes.

    This heuristic rates exactly one test.  Pretty much anything that changes
    the public API can affect ``test_public_bindings``, so the test gets the
    maximum rating (1.0) whenever a changed file lives under ``torch/`` or is
    one of the extra paths listed in ``additional_files``.
    """

    # Name of the single test this heuristic rates.
    test_public_bindings = "test_public_bindings"
    # Paths outside ``torch/`` whose modification should also trigger the test.
    # A tuple rather than a list, so this class-level attribute (shared by
    # every instance) cannot be mutated accidentally.
    additional_files = ("test/allowlist_for_publicAPI.json",)

    def __init__(self, **kwargs: Any) -> None:
        # Keyword arguments are forwarded unchanged to HeuristicInterface.
        # Under PEP 484 the annotation on ``**kwargs`` describes each *value*,
        # so it is ``Any`` rather than ``dict[str, Any]``.
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Rate ``test_public_bindings`` at 1.0 if the change may touch the public API.

        Args:
            tests: Candidate test names, passed through to ``TestPrioritizations``.

        Returns:
            A ``TestPrioritizations`` built from ``tests``. It rates
            ``TestRun("test_public_bindings")`` at 1.0 when any changed file
            starts with ``torch/`` or appears in ``additional_files``.
            Otherwise the ratings mapping is empty.
        """
        test_ratings: dict[TestRun, float] = {}
        try:
            changed_files = query_changed_files()
        except Exception as e:
            # Best effort: if the changed-file list cannot be obtained for any
            # reason, warn and fall back to "no changes" instead of failing
            # target determination.
            warn(f"Can't query changed test files due to {e}")
            changed_files = []

        if any(
            file.startswith("torch/") or file in self.additional_files
            for file in changed_files
        ):
            test_ratings[TestRun(self.test_public_bindings)] = 1.0
        return TestPrioritizations(tests, test_ratings)
37