"""Download a platform-specific binary and verify its SHA256 hash.

The download URL and expected hash come from a JSON config keyed first by
linter name and then by host platform; each platform entry provides
``download_url`` and ``hash`` (see ``main``).
"""

import argparse
import hashlib
import json
import logging
import os
import platform
import stat
import subprocess
import sys
import urllib.error
import urllib.request
from pathlib import Path


# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()
# Platform plus processor name, used to look up processor-specific binaries
# when the config has no entry for the bare platform name.
HOST_PLATFORM_ARCH = platform.system() + "-" + platform.processor()

# PyTorch directory root
try:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        # We have a fallback below, so don't spam stderr with git's error.
        stderr=subprocess.DEVNULL,
        check=True,
    )
    PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except (subprocess.CalledProcessError, OSError):
    # git failed (CalledProcessError) or could not be executed at all, e.g. it
    # is not installed (FileNotFoundError, an OSError). Fall back to the fourth
    # ancestor of this file's path, which assumes this script lives three
    # directories below the repository root.
    path_ = os.path.abspath(__file__)
    for _ in range(4):
        path_ = os.path.dirname(path_)
    PYTORCH_ROOT = path_

# When True, report what would be done but do not download or delete files.
DRY_RUN = False

# Read size used when hashing files; bounds memory regardless of file content.
_HASH_CHUNK_SIZE = 1 << 16

# --dry-run values (case-insensitive) that mean "perform the real work".
_FALSE_STRINGS = frozenset({"0", "false", "no", "off", ""})


def compute_file_sha256(path: str) -> str:
    """Compute the SHA256 hash of a file and return it as a hex string.

    Args:
        path: File to hash.

    Returns:
        The lowercase hex digest, or an empty string if ``path`` does not exist.
    """
    if not os.path.exists(path):
        return ""

    digest = hashlib.sha256()
    # Read fixed-size chunks rather than iterating "lines": binaries may contain
    # few newline bytes, which would make line iteration load huge chunks.
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(_HASH_CHUNK_SIZE), b""):
            digest.update(chunk)

    return digest.hexdigest()


def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """Pretty printer for file download progress.

    Intended as the ``reporthook`` of ``urllib.request.urlretrieve``, which
    passes ``file_size == -1`` when the total size is unknown. Nothing is
    printed unless the size is positive, which also avoids dividing by zero
    for an empty response.
    """
    if file_size > 0:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = "#" * int(64 * percent)
        sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
        # "\r" rewrites the line in place; flush so each update is visible.
        sys.stdout.flush()


def check(binary_path: Path, reference_hash: str) -> bool:
    """Check whether the binary exists and is the right one.

    If the hash differs, delete the actual binary (unless in dry run mode,
    where the mismatch is only logged).

    Args:
        binary_path: Location of the binary to verify.
        reference_hash: Expected SHA256 hex digest.

    Returns:
        True only if the file exists and its hash equals ``reference_hash``.
    """
    if not binary_path.exists():
        logging.info("%s does not exist.", binary_path)
        return False

    existing_binary_hash = compute_file_sha256(str(binary_path))
    if existing_binary_hash == reference_hash:
        return True

    logging.warning(
        """\
Found binary hash does not match reference!

Found hash: %s
Reference hash: %s

Deleting %s just to be safe.
""",
        existing_binary_hash,
        reference_hash,
        binary_path,
    )
    if DRY_RUN:
        logging.critical(
            "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
        )
        return False

    try:
        binary_path.unlink()
    except OSError as e:
        logging.critical("Failed to delete binary: %s", e)
        logging.critical(
            "Delete this binary as soon as possible and do not execute it!"
        )

    return False


def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """Download a binary if a correct one is not already present, then verify it.

    The binary lives at ``output_dir/name``. An existing file with the right
    SHA256 hash is reused; otherwise ``url`` is fetched, hash-checked against
    ``reference_bin_hash`` and made executable by its owner.

    Returns:
        True if a verified binary is in place (or, in dry run mode, if there is
        nothing left to do); False if the download or hash check failed.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info("Correct binary already exists at %s. Exiting.", binary_path)
        return True

    logging.info("Downloading %s to %s", url, binary_path)

    if DRY_RUN:
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    # Create the output folder only for a real run, so dry run mode does not
    # create directories.
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    show_progress = sys.stdout.isatty()
    try:
        urllib.request.urlretrieve(
            url,
            binary_path,
            reporthook=report_download_progress if show_progress else None,
        )
    except (urllib.error.URLError, OSError) as e:
        # URLError/HTTPError (network) and OSError (local write) are reported
        # as a failed initialization instead of a traceback.
        logging.critical("Failed to download %s from %s: %s", name, url, e)
        return False
    finally:
        if show_progress:
            # Terminate the in-place progress bar line.
            sys.stdout.write("\n")

    logging.info("Downloaded %s successfully.", name)

    # Check the downloaded binary (a mismatching file is deleted by check()).
    if not check(binary_path, reference_bin_hash):
        logging.critical("Downloaded binary %s failed its hash check", name)
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)

    logging.info("Using %s located at %s", name, binary_path)
    return True


def _parse_dry_run(value: object) -> bool:
    """Interpret a ``--dry-run`` value; "0"/"false"/"no"/"off"/"" mean False."""
    return str(value).strip().lower() not in _FALSE_STRINGS


def main() -> None:
    """Command-line entry point: parse arguments, read config, download."""
    global DRY_RUN

    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        default="0",
        help="do not download, just print what would be done",
    )

    args = parser.parse_args()
    # Previously an omitted flag compared False != "0" and enabled dry run.
    DRY_RUN = _parse_dry_run(args.dry_run)

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    with open(args.config_json, encoding="utf-8") as f:
        config = json.load(f)

    if args.linter not in config:
        logging.error("Linter %s not found in %s", args.linter, args.config_json)
        sys.exit(1)
    config = config[args.linter]

    # Allow processor specific binaries for platform (namely Intel and M1 binaries for MacOS)
    host_platform = HOST_PLATFORM if HOST_PLATFORM in config else HOST_PLATFORM_ARCH
    # If the host platform is not in the config, it is unsupported.
    if host_platform not in config:
        logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
        sys.exit(1)

    url = config[host_platform]["download_url"]
    reference_hash = config[host_platform]["hash"]

    ok = download(args.output_name, args.output_dir, url, reference_hash)
    if not ok:
        logging.critical("Unable to initialize %s", args.linter)
        sys.exit(1)


if __name__ == "__main__":
    main()