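"""Download PyTorch CI test-report artifacts for a given commit.

Uses the GitHub CLI (`gh`) to find the "pull" workflow run for the commit,
then downloads and unzips the test-report zip artifacts for the requested
configs (see CONFIGS below) via hud.pytorch.org.

Requires the `gh`, `wget`, and `unzip` command-line tools.
"""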
import json
import os
import pprint
import re
import subprocess

import requests


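# Each config name maps to the set of CI job names whose test reports we need.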
CONFIGS = {
    "dynamo39": {
        "linux-focal-py3.9-clang10 / test (dynamo, 1, 3, linux.2xlarge)",
        "linux-focal-py3.9-clang10 / test (dynamo, 2, 3, linux.2xlarge)",
        "linux-focal-py3.9-clang10 / test (dynamo, 3, 3, linux.2xlarge)",
    },
    "dynamo311": {
        "linux-focal-py3.11-clang10 / test (dynamo, 1, 3, linux.2xlarge)",
        "linux-focal-py3.11-clang10 / test (dynamo, 2, 3, linux.2xlarge)",
        "linux-focal-py3.11-clang10 / test (dynamo, 3, 3, linux.2xlarge)",
    },
    "eager311": {
        "linux-focal-py3.11-clang10 / test (default, 1, 3, linux.2xlarge)",
        "linux-focal-py3.11-clang10 / test (default, 2, 3, linux.2xlarge)",
        "linux-focal-py3.11-clang10 / test (default, 3, 3, linux.2xlarge)",
    },
}


def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")):
    log_dir = "tmp_test_reports_" + commit_sha

    def subdir_path(config):
        return f"{log_dir}/{config}"

    for config in configs:
        assert config in CONFIGS, config
    subdir_paths = [subdir_path(config) for config in configs]

    # See which configs we haven't downloaded logs for yet
    missing_configs = []
    for config, path in zip(configs, subdir_paths):
        if os.path.exists(path):
            continue
        missing_configs.append(config)
    if len(missing_configs) == 0:
        print(
            f"All required logs appear to exist; not downloading again. Run `rm -rf {log_dir}` if this is not the case."
        )
        return subdir_paths

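    # Find the "pull" workflow run for this commit via the GitHub CLI, then
    # parse the `gh run view` output into a mapping of job name -> job id.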
    output = subprocess.check_output(
        ["gh", "run", "list", "-c", commit_sha, "-w", "pull", "--json", "databaseId"]
    ).decode()
    workflow_run_id = str(json.loads(output)[0]["databaseId"])
    output = subprocess.check_output(["gh", "run", "view", workflow_run_id])
    workflow_jobs = parse_workflow_jobs(output)
    print("found the following workflow jobs:")
    pprint.pprint(workflow_jobs)

    # Figure out which jobs we need to download logs for
    required_jobs = []
    for config in configs:
        required_jobs.extend(list(CONFIGS[config]))
    for job in required_jobs:
        assert (
            job in workflow_jobs
        ), f"{job} not found. Is the commit_sha correct? Has the job finished running? The GitHub API may take a couple of minutes to update."

    # This endpoint lists all artifacts for the workflow run.
    listings = requests.get(
        f"https://hud.pytorch.org/api/artifacts/s3/{workflow_run_id}"
    ).json()

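    # Each job's test reports live in a zip artifact named
    # "test-reports-*_<job id>.zip"; download it and unzip it in place.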
    def download_report(job_name, subdir):
        job_id = workflow_jobs[job_name]
        for listing in listings:
            name = listing["name"]
            if not name.startswith("test-reports-"):
                continue
            if name.endswith(f"_{job_id}.zip"):
                url = listing["url"]
                subprocess.run(["wget", "-P", subdir, url], check=True)
                path_to_zip = f"{subdir}/{name}"
                dir_name = path_to_zip[:-4]  # strip the ".zip" suffix
                subprocess.run(["unzip", path_to_zip, "-d", dir_name], check=True)
                return
        raise AssertionError(
            f"no test-reports artifact found for {job_name} (job id {job_id})"
        )

    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    for config in set(configs) - set(missing_configs):
        print(
            f"Logs for {config} already exist; not downloading again. Run `rm -rf {subdir_path(config)}` if this is not the case."
        )
    for config in missing_configs:
        subdir = subdir_path(config)
        os.mkdir(subdir)
        job_names = CONFIGS[config]
        for job_name in job_names:
            download_report(job_name, subdir)

    return subdir_paths


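# `gh run view` prints one line per job, roughly of the form
# "<job name> in <duration> (ID <job id>)"; extract a {job name: job id}
# dict from those lines.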
def parse_workflow_jobs(output):
    result = {}
    lines = output.decode().split("\n")
    for line in lines:
        match = re.search(r"(\S+ / .*) in .* \(ID (\d+)\)", line)
        if match is None:
            continue
        result[match.group(1)] = match.group(2)
    return result
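

# Example usage (illustrative; `gh` must be authenticated, and "<commit sha>"
# is a placeholder for a real commit on pytorch/pytorch):
#
#   subdir_paths = download_reports("<commit sha>", configs=("dynamo39",))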