# tools/testing/target_determination/heuristics/utils.py
from __future__ import annotations

import json
import os
import re
import subprocess
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import cast, Dict, TYPE_CHECKING
from urllib.request import Request, urlopen
from warnings import warn


if TYPE_CHECKING:
    from tools.testing.test_run import TestRun

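# Five parents up from tools/testing/target_determination/heuristics/utils.py,
# i.e. the repository root.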
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent


def python_test_file_to_test_name(tests: set[str]) -> set[str]:
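    """
    Map changed Python test files to test module names, e.g.
    ``test/test_torch.py`` -> ``test_torch`` (using the platform path
    separator); files outside ``test/`` or not ending in ``.py`` are dropped.
    """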
    prefix = f"test{os.path.sep}"
    valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
    valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}

    return valid_tests


@lru_cache(maxsize=None)
def get_pr_number() -> int | None:
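    """
    Return the PR number from the ``PR_NUMBER`` environment variable, falling
    back to the trailing number of a ``GITHUB_REF`` tag ref (e.g. a ciflow tag
    ref such as ``refs/tags/ciflow/trunk/12345``); ``None`` if neither is set.
    """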
    pr_number = os.environ.get("PR_NUMBER", "")
    if pr_number == "":
        re_match = re.match(r"^refs/tags/.*/(\d+)$", os.environ.get("GITHUB_REF", ""))
        if re_match is not None:
            pr_number = re_match.group(1)
    if pr_number != "":
        return int(pr_number)
    return None


@lru_cache(maxsize=None)
def get_merge_base() -> str:
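    """
    Return the merge base between HEAD and the PR's base branch (looked up via
    the GitHub API when a PR number is available) or the default branch
    otherwise.  If HEAD is already on the default branch, fall back to
    ``HEAD^`` so there is still a commit range to diff against.
    """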
    pr_number = get_pr_number()
    if pr_number is not None:
        github_token = os.environ.get("GITHUB_TOKEN")
        headers = {
            "Accept": "application/vnd.github.v3+json",
            "Authorization": f"token {github_token}",
        }
        url = f"https://api.github.com/repos/pytorch/pytorch/pulls/{pr_number}"
        with urlopen(Request(url, headers=headers)) as conn:
            pr_info = json.loads(conn.read().decode())
            base = f"origin/{pr_info['base']['ref']}"
        merge_base = (
            subprocess.check_output(["git", "merge-base", base, "HEAD"])
            .decode()
            .strip()
        )
        return merge_base
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    merge_base = (
        subprocess.check_output(["git", "merge-base", default_branch, "HEAD"])
        .decode()
        .strip()
    )

    head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()

    if merge_base == head:
        # We are on the default branch, so check for changes since the last commit
        merge_base = "HEAD^"
    return merge_base


def query_changed_files() -> list[str]:
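    """
    Return the files changed between the merge base and HEAD, as reported by
    ``git diff --name-only <merge_base> HEAD``.
    """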
    base_commit = get_merge_base()

    proc = subprocess.run(
        ["git", "diff", "--name-only", base_commit, "HEAD"],
        capture_output=True,
        check=False,
    )
    print(f"base_commit: {base_commit}")

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    lines = proc.stdout.decode().strip().split("\n")
    lines = [line.strip() for line in lines]
    print(f"Changed files: {lines}")
    return lines


@lru_cache(maxsize=None)
def get_git_commit_info() -> str:
    """Gets the commit info (``git log``) between the merge base and HEAD."""
    base_commit = get_merge_base()

    return (
        subprocess.check_output(
            ["git", "log", f"{base_commit}..HEAD"],
        )
        .decode()
        .strip()
    )


@lru_cache(maxsize=None)
def get_issue_or_pr_body(number: int) -> str:
    """Gets the body of an issue or PR"""
    github_token = os.environ.get("GITHUB_TOKEN")
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {github_token}",
    }
    # Despite the 'issues' in the link, this also works for PRs
    url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}"
    with urlopen(Request(url, headers=headers)) as conn:
        body: str = json.loads(conn.read().decode())["body"] or ""
        return body


def normalize_ratings(
    ratings: dict[TestRun, float], max_value: float, min_value: float = 0
) -> dict[TestRun, float]:
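    """
    Linearly rescale each rating ``r`` to
    ``r / max(ratings.values()) * (max_value - min_value) + min_value``,
    so the largest rating maps to ``max_value``; see the worked examples below.
    """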
    # Takes the ratings, makes the max value into max_value, and proportionally
    # distributes the rest of the ratings.
    # Ex [1,2,3,4] and max_value 8 gets converted to [2,4,6,8]
    # Assumes all rankings are >= 0
    # min_value is what 0 gets mapped to and shifts the values accordingly.  Ex
    # [1,2,3,4], min_value 1, max_value 5 gets converted to [2,3,4,5]
    # Don't modify in place
    if len(ratings) == 0:
        return ratings
    min_rating = min(ratings.values())
    assert min_rating > 0
    max_rating = max(ratings.values())
    assert max_rating > 0
    normalized_ratings = {}
    for tf, rank in ratings.items():
        normalized_ratings[tf] = rank / max_rating * (max_value - min_value) + min_value
    return normalized_ratings


def get_ratings_for_tests(file: str | Path) -> dict[str, float]:
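    """
    Load the correlation data stored at ``file`` (a JSON mapping of
    ``{changed_file: {test_file: score}}``, relative to the repo root) and
    return each test file's score summed over the currently changed files.
    """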
    path = REPO_ROOT / file
    if not os.path.exists(path):
        print(f"could not find path {path}")
        return {}
    with open(path) as f:
        test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
    try:
        changed_files = query_changed_files()
    except Exception as e:
        warn(f"Can't query changed test files due to {e}")
        return {}
    ratings: dict[str, float] = defaultdict(float)
    for file in changed_files:
        for test_file, score in test_file_ratings.get(file, {}).items():
            ratings[test_file] += score
    return ratings


def get_correlated_tests(file: str | Path) -> list[str]:
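    """Return test files sorted by descending correlation score with the changed files."""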
    ratings = get_ratings_for_tests(file)
    prioritize = sorted(ratings, key=lambda x: -ratings[x])
    return prioritize