xref: /aosp_15_r20/external/pytorch/tools/stats/upload_stats_lib.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

import gzip
import io
import json
import os
import time
import zipfile
from pathlib import Path
from typing import Any, Callable

import boto3  # type: ignore[import]
import requests
import rockset  # type: ignore[import]


PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
S3_RESOURCE = boto3.resource("s3")

# NB: In CI, a flaky test is usually retried 3 times, and then the test file is
# rerun 2 more times (3 runs of the file in total), hence 3 * 3
MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
# NB: Rockset has an upper limit of 5000 documents in one request
BATCH_SIZE = 5000


def _get_request_headers() -> dict[str, str]:
    return {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": "token " + os.environ["GITHUB_TOKEN"],
    }


def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
    """Get the URLs of all workflow artifacts whose name starts with `prefix`."""
    response = requests.get(
        f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
        headers=_get_request_headers(),
    )
    artifacts = response.json()["artifacts"]
    while "next" in response.links:
        response = requests.get(
            response.links["next"]["url"], headers=_get_request_headers()
        )
        artifacts.extend(response.json()["artifacts"])

    artifact_urls = {}
    for artifact in artifacts:
        if artifact["name"].startswith(prefix):
            artifact_urls[Path(artifact["name"])] = artifact["archive_download_url"]
    return artifact_urls
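# Illustrative usage sketch (not part of the original module). Assumes
# GITHUB_TOKEN is set; the run id is a placeholder:
#
#   urls = _get_artifact_urls("test-reports", workflow_run_id=123456789)
#   for name, url in urls.items():
#       print(f"{name}: {url}")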


def _download_artifact(
    artifact_name: Path, artifact_url: str, workflow_run_attempt: int
) -> Path | None:
    # [Artifact run attempt]
    # All artifacts on a workflow share a single namespace. However, we can
    # re-run a workflow and produce a new set of artifacts. To avoid name
    # collisions, we add `-runattempt<run #>-` somewhere in the artifact name.
    #
    # This code parses out the run attempt number from the artifact name. If it
    # doesn't match the one specified on the command line, skip it.
    atoms = str(artifact_name).split("-")
    for atom in atoms:
        if atom.startswith("runattempt"):
            found_run_attempt = int(atom[len("runattempt") :])
            if workflow_run_attempt != found_run_attempt:
                print(
                    f"Skipping {artifact_name} as it is from a different run attempt. "
                    f"Expected {workflow_run_attempt}, found {found_run_attempt}."
                )
                return None

    print(f"Downloading {artifact_name}")

    response = requests.get(artifact_url, headers=_get_request_headers())
    with open(artifact_name, "wb") as f:
        f.write(response.content)
    return artifact_name


def download_s3_artifacts(
    prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> list[Path]:
    bucket = S3_RESOURCE.Bucket("gha-artifacts")
    objs = bucket.objects.filter(
        Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
    )

    found_one = False
    paths = []
    for obj in objs:
        found_one = True
        p = Path(Path(obj.key).name)
        print(f"Downloading {p}")
        with open(p, "wb") as f:
            f.write(obj.get()["Body"].read())
        paths.append(p)

    if not found_one:
        print(
            "::warning title=s3 artifacts not found::"
            "Didn't find any test reports in s3, there might be a bug!"
        )
    return paths
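# Illustrative usage sketch (placeholder run id/attempt): fetch the report
# archives stored in the gha-artifacts bucket for a run, then unzip them with
# the helper defined below:
#
#   for path in download_s3_artifacts("test-report", 123456789, 1):
#       unzip(path)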


def download_gha_artifacts(
    prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> list[Path]:
    artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
    paths = []
    for name, url in artifact_urls.items():
        path = _download_artifact(Path(name), url, workflow_run_attempt)
        if path is not None:
            paths.append(path)
    return paths
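# Illustrative usage sketch (placeholder ids): same contract as
# download_s3_artifacts, but the archives are fetched through the GitHub
# Actions artifacts API instead of S3:
#
#   paths = download_gha_artifacts("test-report", 123456789, 1)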


def upload_to_rockset(
    collection: str,
    docs: list[Any],
    workspace: str = "commons",
    client: Any = None,
) -> None:
    if not client:
        client = rockset.RocksetClient(
            host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
        )

    index = 0
    while index < len(docs):
        from_index = index
        to_index = min(from_index + BATCH_SIZE, len(docs))
        print(f"Writing {to_index - from_index} documents to Rockset")

        client.Documents.add_documents(
            collection=collection,
            data=docs[from_index:to_index],
            workspace=workspace,
        )
        index += BATCH_SIZE

    print("Done!")
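# Illustrative usage sketch: docs is a list of JSON-serializable dicts, written
# in batches of BATCH_SIZE. Assumes ROCKSET_API_KEY is set; the collection name
# below is a placeholder:
#
#   upload_to_rockset(collection="test_run", docs=docs)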


def upload_to_dynamodb(
    dynamodb_table: str,
    repo: str,
    docs: list[Any],
    generate_partition_key: Callable[[str, dict[str, Any]], str] | None,
) -> None:
    print(f"Writing {len(docs)} documents to DynamoDB {dynamodb_table}")
    # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/dynamodb.html#batch-writing
    with boto3.resource("dynamodb").Table(dynamodb_table).batch_writer() as batch:
        for doc in docs:
            if generate_partition_key:
                doc["dynamoKey"] = generate_partition_key(repo, doc)
            # This is to move away from Rockset's _event_time field, which we
            # cannot use when reimporting the data
            doc["timestamp"] = int(round(time.time() * 1000))
            batch.put_item(Item=doc)
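# Illustrative usage sketch: the partition-key callback below is hypothetical,
# as are the table name and the document fields it reads:
#
#   def gen_key(repo: str, doc: dict[str, Any]) -> str:
#       return f"{repo}/{doc['workflow_id']}/{doc['id']}"
#
#   upload_to_dynamodb("my-dynamodb-table", "pytorch/pytorch", docs, gen_key)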


def upload_to_s3(
    bucket_name: str,
    key: str,
    docs: list[dict[str, Any]],
) -> None:
    print(f"Writing {len(docs)} documents to S3")
    body = io.StringIO()
    for doc in docs:
        json.dump(doc, body)
        body.write("\n")

    S3_RESOURCE.Object(
        bucket_name,
        key,
    ).put(
        Body=gzip.compress(body.getvalue().encode()),
        ContentEncoding="gzip",
        ContentType="application/json",
    )
    print("Done!")
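# Illustrative usage sketch (placeholder bucket/key): the documents are stored
# as gzip-compressed JSON Lines, one JSON object per line:
#
#   upload_to_s3("my-bucket", "test_run/123456789/1", docs)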


def read_from_s3(
    bucket_name: str,
    key: str,
) -> list[dict[str, Any]]:
    print(f"Reading from s3://{bucket_name}/{key}")
    body = (
        S3_RESOURCE.Object(
            bucket_name,
            key,
        )
        .get()["Body"]
        .read()
    )
    results = gzip.decompress(body).decode().split("\n")
    return [json.loads(result) for result in results if result]
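# Illustrative round trip (placeholder bucket/key): for JSON-serializable
# documents, read_from_s3 reverses upload_to_s3:
#
#   upload_to_s3("my-bucket", "some/key", docs)
#   assert read_from_s3("my-bucket", "some/key") == docs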


def upload_workflow_stats_to_s3(
    workflow_run_id: int,
    workflow_run_attempt: int,
    collection: str,
    docs: list[dict[str, Any]],
) -> None:
    bucket_name = "ossci-raw-job-status"
    key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
    upload_to_s3(bucket_name, key, docs)
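# Illustrative usage sketch (placeholder ids): the documents end up at
# s3://ossci-raw-job-status/<collection>/<workflow_run_id>/<workflow_run_attempt>:
#
#   upload_workflow_stats_to_s3(123456789, 1, "test_run", docs)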


def upload_file_to_s3(
    file_name: str,
    bucket: str,
    key: str,
) -> None:
    """
    Upload a local file to S3
    """
    print(f"Upload {file_name} to s3://{bucket}/{key}")
    boto3.client("s3").upload_file(
        file_name,
        bucket,
        key,
    )


def unzip(p: Path) -> None:
    """Unzip the provided zipfile to a similarly-named directory.

    Raises zipfile.BadZipFile if `p` is not a zipfile.

    Looks like: /tmp/test-reports.zip -> /tmp/unzipped-test-reports/
    """
    assert p.is_file()
    unzipped_dir = p.with_name("unzipped-" + p.stem)
    print(f"Extracting {p} to {unzipped_dir}")

    with zipfile.ZipFile(p, "r") as zip_file:
        zip_file.extractall(unzipped_dir)


def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
    """
    Check if the test report comes from the rerun_disabled_tests workflow, where
    each test is run multiple times
    """
    return all(
        t.get("num_green", 0) + t.get("num_red", 0) > MAX_RETRY_IN_NON_DISABLED_MODE
        for t in tests.values()
    )
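# Illustrative sketch: with MAX_RETRY_IN_NON_DISABLED_MODE == 9, a report where
# every test ran more than 9 times is treated as a rerun_disabled_tests report:
#
#   assert is_rerun_disabled_tests({"test_foo": {"num_green": 48, "num_red": 2}})
#   assert not is_rerun_disabled_tests({"test_bar": {"num_green": 1, "num_red": 0}})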


def get_job_id(report: Path) -> int | None:
    # [Job id in artifacts]
    # Retrieve the job id from the report path. In our GHA workflows, we append
    # the job id to the end of the report name, so `report` looks like:
    #     unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
    # and we want to get `5596745227` out of it.
    try:
        return int(report.parts[0].rpartition("_")[2])
    except ValueError:
        return None
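# Illustrative sketch: the job id is parsed from the first path component (see
# [Job id in artifacts] above); paths without a trailing `_<job id>` yield None:
#
#   report = Path("unzipped-test-reports-foo_5596745227/test/TEST-foo.xml")
#   assert get_job_id(report) == 5596745227
#   assert get_job_id(Path("no-job-id/TEST-foo.xml")) is None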