1from __future__ import annotations 2 3import json 4import os 5import re 6import subprocess 7from collections import defaultdict 8from functools import lru_cache 9from pathlib import Path 10from typing import cast, Dict, TYPE_CHECKING 11from urllib.request import Request, urlopen 12from warnings import warn 13 14 15if TYPE_CHECKING: 16 from tools.testing.test_run import TestRun 17 18REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent 19 20 21def python_test_file_to_test_name(tests: set[str]) -> set[str]: 22 prefix = f"test{os.path.sep}" 23 valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")} 24 valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests} 25 26 return valid_tests 27 28 29@lru_cache(maxsize=None) 30def get_pr_number() -> int | None: 31 pr_number = os.environ.get("PR_NUMBER", "") 32 if pr_number == "": 33 re_match = re.match(r"^refs/tags/.*/(\d+)$", os.environ.get("GITHUB_REF", "")) 34 if re_match is not None: 35 pr_number = re_match.group(1) 36 if pr_number != "": 37 return int(pr_number) 38 return None 39 40 41@lru_cache(maxsize=None) 42def get_merge_base() -> str: 43 pr_number = get_pr_number() 44 if pr_number is not None: 45 github_token = os.environ.get("GITHUB_TOKEN") 46 headers = { 47 "Accept": "application/vnd.github.v3+json", 48 "Authorization": f"token {github_token}", 49 } 50 url = f"https://api.github.com/repos/pytorch/pytorch/pulls/{pr_number}" 51 with urlopen(Request(url, headers=headers)) as conn: 52 pr_info = json.loads(conn.read().decode()) 53 base = f"origin/{pr_info['base']['ref']}" 54 merge_base = ( 55 subprocess.check_output(["git", "merge-base", base, "HEAD"]) 56 .decode() 57 .strip() 58 ) 59 return merge_base 60 default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}" 61 merge_base = ( 62 subprocess.check_output(["git", "merge-base", default_branch, "HEAD"]) 63 .decode() 64 .strip() 65 ) 66 67 head = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() 68 69 if merge_base == head: 70 # We are on the default branch, so check for changes since the last commit 71 merge_base = "HEAD^" 72 return merge_base 73 74 75def query_changed_files() -> list[str]: 76 base_commit = get_merge_base() 77 78 proc = subprocess.run( 79 ["git", "diff", "--name-only", base_commit, "HEAD"], 80 capture_output=True, 81 check=False, 82 ) 83 print(f"base_commit: {base_commit}") 84 85 if proc.returncode != 0: 86 raise RuntimeError("Unable to get changed files") 87 88 lines = proc.stdout.decode().strip().split("\n") 89 lines = [line.strip() for line in lines] 90 print(f"Changed files: {lines}") 91 return lines 92 93 94@lru_cache(maxsize=None) 95def get_git_commit_info() -> str: 96 """Gets the commit info since the last commit on the default branch.""" 97 base_commit = get_merge_base() 98 99 return ( 100 subprocess.check_output( 101 ["git", "log", f"{base_commit}..HEAD"], 102 ) 103 .decode() 104 .strip() 105 ) 106 107 108@lru_cache(maxsize=None) 109def get_issue_or_pr_body(number: int) -> str: 110 """Gets the body of an issue or PR""" 111 github_token = os.environ.get("GITHUB_TOKEN") 112 headers = { 113 "Accept": "application/vnd.github.v3+json", 114 "Authorization": f"token {github_token}", 115 } 116 # Despite the 'issues' in the link, this also works for PRs 117 url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}" 118 with urlopen(Request(url, headers=headers)) as conn: 119 body: str = json.loads(conn.read().decode())["body"] or "" 120 return body 121 122 123def normalize_ratings( 124 ratings: dict[TestRun, float], max_value: float, min_value: float = 0 125) -> dict[TestRun, float]: 126 # Takse the ratings, makes the max value into max_value, and proportionally 127 # distributes the rest of the ratings. 128 # Ex [1,2,3,4] and max_value 8 gets converted to [2,4,6,8] 129 # Assumes all rankings are >= 0 130 # min_value is what 0 gets mapped to and shifts the values accordingly. Ex 131 # [1,2,3,4], min_value 1, max_value 5 gets converted to [2,3,4,5] 132 # Don't modify in place 133 if len(ratings) == 0: 134 return ratings 135 min_rating = min(ratings.values()) 136 assert min_rating > 0 137 max_rating = max(ratings.values()) 138 assert max_rating > 0 139 normalized_ratings = {} 140 for tf, rank in ratings.items(): 141 normalized_ratings[tf] = rank / max_rating * (max_value - min_value) + min_value 142 return normalized_ratings 143 144 145def get_ratings_for_tests(file: str | Path) -> dict[str, float]: 146 path = REPO_ROOT / file 147 if not os.path.exists(path): 148 print(f"could not find path {path}") 149 return {} 150 with open(path) as f: 151 test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f)) 152 try: 153 changed_files = query_changed_files() 154 except Exception as e: 155 warn(f"Can't query changed test files due to {e}") 156 return {} 157 ratings: dict[str, float] = defaultdict(float) 158 for file in changed_files: 159 for test_file, score in test_file_ratings.get(file, {}).items(): 160 ratings[test_file] += score 161 return ratings 162 163 164def get_correlated_tests(file: str | Path) -> list[str]: 165 ratings = get_ratings_for_tests(file) 166 prioritize = sorted(ratings, key=lambda x: -ratings[x]) 167 return prioritize 168