xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/s3_init.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import hashlib
3import json
4import logging
5import os
6import platform
7import stat
8import subprocess
9import sys
10import urllib.error
11import urllib.request
12from pathlib import Path
13
14
# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()
# Platform name and processor string joined by "-"; used to look up
# processor-specific binaries when the config has no plain platform entry.
HOST_PLATFORM_ARCH = platform.system() + "-" + platform.processor()

# PyTorch directory root
try:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        check=True,
    )
    PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: git ran but exited non-zero (e.g. not in a work tree).
    # OSError (FileNotFoundError, PermissionError, ...): git could not be
    # started at all, e.g. because it is not installed.
    # In both cases compute the repo root from this file's location: the first
    # dirname() strips the file name, the remaining three walk up three folders.
    path_ = os.path.abspath(__file__)
    for _ in range(4):
        path_ = os.path.dirname(path_)
    PYTORCH_ROOT = path_

# When True, only log what would be done; nothing is downloaded or deleted.
DRY_RUN = False
35
36
def compute_file_sha256(path: str) -> str:
    """Compute the SHA256 hash of a file and return it as a hex string.

    Args:
        path: Filesystem path of the file to hash.

    Returns:
        The hexadecimal SHA256 digest of the file's bytes, or an empty
        string if nothing exists at ``path``.
    """
    # If the file doesn't exist, return an empty string.
    if not os.path.exists(path):
        return ""

    digest = hashlib.sha256()
    chunk_size = 1 << 20  # 1 MiB

    # Read fixed-size chunks instead of iterating "lines": a binary file may
    # contain no newline bytes at all, in which case line iteration would pull
    # the whole file into memory as a single "line".
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)

    # Return the hash as a hexadecimal string.
    return digest.hexdigest()
52
53
def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """
    Pretty printer for file download progress.

    Intended as the ``reporthook`` callback of ``urllib.request.urlretrieve``.
    Redraws a 64-character progress bar on the current line of stdout.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size of the download in bytes, or -1 if unknown.
    """
    # Nothing sensible can be drawn when the total size is unknown (-1), and
    # a size of 0 (empty download) would otherwise divide by zero below.
    if file_size <= 0:
        return

    # The last chunk is usually only partially filled, so clamp to 100%.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
    # No newline is written, so flush explicitly or a buffered stdout would
    # not show the bar until the download has finished.
    sys.stdout.flush()
64
65
def check(binary_path: Path, reference_hash: str) -> bool:
    """Tell whether ``binary_path`` exists and has the expected SHA256 hash.

    A file whose hash differs from ``reference_hash`` is removed, unless
    dry run mode is active, in which case it is left in place and only
    reported.

    Args:
        binary_path: Location of the binary to verify.
        reference_hash: Expected SHA256 hex digest.

    Returns:
        True only when the file exists and its hash equals
        ``reference_hash``; False in every other case.
    """
    if not binary_path.exists():
        logging.info("%s does not exist.", binary_path)
        return False

    found_hash = compute_file_sha256(str(binary_path))
    if found_hash != reference_hash:
        mismatch_message = (
            "Found binary hash does not match reference!\n"
            "\n"
            "Found hash: %s\n"
            "Reference hash: %s\n"
            "\n"
            "Deleting %s just to be safe.\n"
        )
        logging.warning(mismatch_message, found_hash, reference_hash, binary_path)

        if DRY_RUN:
            logging.critical(
                "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
            )
        else:
            # Removal is best effort: a failure is reported loudly, but the
            # result is False either way.
            try:
                binary_path.unlink()
            except OSError as err:
                logging.critical("Failed to delete binary: %s", err)
                logging.critical(
                    "Delete this binary as soon as possible and do not execute it!"
                )
        return False

    return True
107
108
def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """
    Download a platform-appropriate binary if one doesn't already exist at the expected location and verifies
    that it is the right binary by checking its SHA256 hash against the expected hash.

    Args:
        name: File name of the binary inside ``output_dir``.
        output_dir: Directory the binary is stored in; created if missing.
        url: Location to download the binary from.
        reference_bin_hash: Expected SHA256 hex digest of the binary.

    Returns:
        True if a correct binary is in place afterwards (or, in dry run mode,
        once the intended download has been logged); False if the download
        fails or the downloaded file does not match the expected hash.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info("Correct binary already exists at %s. Exiting.", binary_path)
        return True

    # Download the binary
    logging.info("Downloading %s to %s", url, binary_path)

    # Bail out before touching the filesystem: dry run mode must not even
    # create the output folder.
    if DRY_RUN:
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    # Create the output folder
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        urllib.request.urlretrieve(
            url,
            binary_path,
            # Only draw the progress bar when a human is watching.
            reporthook=report_download_progress if sys.stdout.isatty() else None,
        )
    except OSError as e:
        # URLError, HTTPError and ContentTooShortError are all OSError
        # subclasses; report them as a failed download instead of crashing,
        # in line with the boolean return contract.
        logging.critical("Failed to download %s from %s: %s", name, url, e)
        return False

    logging.info("Downloaded %s successfully.", name)

    # Check the downloaded binary
    if not check(binary_path, reference_bin_hash):
        logging.critical("Downloaded binary %s failed its hash check", name)
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)

    logging.info("Using %s located at %s", name, binary_path)
    return True
155
156
# Script entry point: parse the command line, look up the download URL and
# expected hash for the requested linter on this platform, then fetch and
# verify the binary. Exits with status 1 on any failure.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        # The value is compared as a string below, so the default must be the
        # string "0"; a boolean False default would never equal "0" and would
        # silently switch dry run mode on whenever the flag is omitted.
        default="0",
        help="do not download, just print what would be done",
    )

    args = parser.parse_args()
    # "0" disables dry run mode; any other value enables it.
    DRY_RUN = args.dry_run != "0"

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    # Use a context manager so the config file handle is closed promptly.
    with open(args.config_json) as config_file:
        config = json.load(config_file)

    if args.linter not in config:
        logging.error("Linter %s not found in %s", args.linter, args.config_json)
        sys.exit(1)
    config = config[args.linter]

    # Allow processor specific binaries for platform (namely Intel and M1 binaries for MacOS)
    host_platform = HOST_PLATFORM if HOST_PLATFORM in config else HOST_PLATFORM_ARCH
    # If neither the plain platform nor the platform-arch key is in the
    # linter's config, the host is unsupported.
    if host_platform not in config:
        logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
        sys.exit(1)

    url = config[host_platform]["download_url"]
    reference_hash = config[host_platform]["hash"]

    ok = download(args.output_name, args.output_dir, url, reference_hash)
    if not ok:
        logging.critical("Unable to initialize %s", args.linter)
        sys.exit(1)
218