1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker# Much of the logging code here was forked from https://github.com/ezyang/ghstack 3*da0073e9SAndroid Build Coastguard Worker# Copyright (c) Edward Z. Yang <[email protected]> 4*da0073e9SAndroid Build Coastguard Worker"""Checks out the nightly development version of PyTorch and installs pre-built 5*da0073e9SAndroid Build Coastguard Workerbinaries into the repo. 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard WorkerYou can use this script to check out a new nightly branch with the following:: 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker $ ./tools/nightly.py checkout -b my-nightly-branch 10*da0073e9SAndroid Build Coastguard Worker $ conda activate pytorch-deps 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard WorkerOr if you would like to re-use an existing conda environment, you can pass in 13*da0073e9SAndroid Build Coastguard Workerthe regular environment parameters (--name or --prefix):: 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env 16*da0073e9SAndroid Build Coastguard Worker $ conda activate my-env 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard WorkerTo install the nightly binaries built with CUDA, you can pass in the flag --cuda:: 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker $ ./tools/nightly.py checkout -b my-nightly-branch --cuda 21*da0073e9SAndroid Build Coastguard Worker $ conda activate pytorch-deps 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard WorkerYou can also use this tool to pull the nightly commits into the current branch as 24*da0073e9SAndroid Build Coastguard Workerwell. 
This can be done with:: 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker $ ./tools/nightly.py pull -n my-env 27*da0073e9SAndroid Build Coastguard Worker $ conda activate my-env 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard WorkerPulling will reinstall the conda dependencies as well as the nightly binaries into 30*da0073e9SAndroid Build Coastguard Workerthe repo directory. 31*da0073e9SAndroid Build Coastguard Worker""" 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Workerimport argparse 36*da0073e9SAndroid Build Coastguard Workerimport contextlib 37*da0073e9SAndroid Build Coastguard Workerimport functools 38*da0073e9SAndroid Build Coastguard Workerimport glob 39*da0073e9SAndroid Build Coastguard Workerimport itertools 40*da0073e9SAndroid Build Coastguard Workerimport json 41*da0073e9SAndroid Build Coastguard Workerimport logging 42*da0073e9SAndroid Build Coastguard Workerimport os 43*da0073e9SAndroid Build Coastguard Workerimport re 44*da0073e9SAndroid Build Coastguard Workerimport shutil 45*da0073e9SAndroid Build Coastguard Workerimport subprocess 46*da0073e9SAndroid Build Coastguard Workerimport sys 47*da0073e9SAndroid Build Coastguard Workerimport tempfile 48*da0073e9SAndroid Build Coastguard Workerimport time 49*da0073e9SAndroid Build Coastguard Workerimport uuid 50*da0073e9SAndroid Build Coastguard Workerfrom ast import literal_eval 51*da0073e9SAndroid Build Coastguard Workerfrom datetime import datetime 52*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path 53*da0073e9SAndroid Build Coastguard Workerfrom platform import system as platform_system 54*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar 55*da0073e9SAndroid Build Coastguard Worker 
REPO_ROOT = Path(__file__).absolute().parent.parent
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
# Conda package specs; presumably passed to conda_solve() by the entry point
# (the call site is not visible here).
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

# Module-wide logger read by timed(). Presumably assigned by the entry point
# once logging_manager() has produced a logger -- not visible here.
LOGGER: logging.Logger | None = None
# Builds a package tarball URL from one entry of the "LINK" actions of a
# `conda ... --json` solve (see conda_solve).
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# Year-first, zero-padded, so run directory names sort chronologically.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
# Matches the credentials part of a URL ("://user:pass@"). Whitespace is
# excluded so a match stays inside a single URL instead of running on to an
# unrelated "@" later on the same line (e.g. "https://host/x ... a@b").
USERNAME_PASSWORD_RE = re.compile(r"://([^\s@]*)@")
# Names of per-run log directories: "<DATETIME_FORMAT>_<uuid>".
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)


class Formatter(logging.Formatter):
    """Log formatter that prints INFO/DEBUG bare and scrubs secrets.

    Every formatted record is passed through ``_filter``, which masks
    ``user:password@`` credentials embedded in URLs and replaces every string
    registered with ``redact``.
    """

    # Maps each secret string to the text that replaces it in the output.
    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    # Remove sensitive information from URLs
    def _filter(self, s: str) -> str:
        s = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for needle, replace in self.redactions.items():
            s = s.replace(needle, replace)
        return s

    def formatMessage(self, record: logging.LogRecord) -> str:
        if record.levelno == logging.INFO or record.levelno == logging.DEBUG:
            # Log INFO/DEBUG without any adornment
            return record.getMessage()
        else:
            # WARNING and above use the configured ``fmt`` (e.g. a level prefix).
            return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        # Filter the fully formatted text so tracebacks are scrubbed as well.
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Redact specific strings; e.g., authorization tokens. This won't
        retroactively redact stuff you've already leaked, so make sure
        you redact things as soon as possible.
        """
        # Don't redact empty strings; this will lead to something
        # that looks like s<REDACTED>t<REDACTED>r<REDACTED>...
        if needle == "":
            return
        self.redactions[needle] = replace


def git(*args: str) -> list[str]:
    """Build a ``git`` argv that always operates on the repository root."""
    return ["git", "-C", str(REPO_ROOT), *args]


@functools.lru_cache
def logging_base_dir() -> Path:
    """Return ``<repo>/nightly/log``, creating it on first use (cached)."""
    base_dir = REPO_ROOT / "nightly" / "log"
    base_dir.mkdir(parents=True, exist_ok=True)
    return base_dir


@functools.lru_cache
def logging_run_dir() -> Path:
    """Return this process's log directory, creating it on first use.

    The result is cached, so every call in one process yields the same
    ``<DATETIME_FORMAT>_<uuid>`` directory under ``logging_base_dir()``.
    """
    base_dir = logging_base_dir()
    # uuid4 is random; uuid1 would embed this host's MAC address in the
    # directory name, which ends up in shared logs and bug reports.
    cur_dir = base_dir / f"{datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid4()}"
    cur_dir.mkdir(parents=True, exist_ok=True)
    return cur_dir


@functools.lru_cache
def logging_record_argv() -> None:
    """Write the command line to ``<run dir>/argv`` (at most once per process)."""
    s = subprocess.list2cmdline(sys.argv)
    (logging_run_dir() / "argv").write_text(s, encoding="utf-8")
def logging_record_exception(e: BaseException) -> None:
    """Write the type name of *e* to ``<run dir>/exception``."""
    (logging_run_dir() / "exception").write_text(type(e).__name__, encoding="utf-8")


def logging_rotate() -> None:
    """Delete all but the 1000 newest run directories under the log base dir.

    Run directories start with a year-first timestamp, so a reverse name sort
    lists the newest first. Only entries whose names fully match
    ``LOG_DIRNAME_RE`` are ever removed.
    """
    log_base = logging_base_dir()
    old_logs = sorted(log_base.iterdir(), reverse=True)
    for stale_log in old_logs[1000:]:
        # Sanity check that it looks like a log
        if LOG_DIRNAME_RE.fullmatch(stale_log.name) is not None:
            shutil.rmtree(stale_log)


@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Setup logging. If a failure starts here we won't
    be able to save the user in a reasonable way.

    Logging structure: there is one logger (named "conda-pytorch")
    and it processes all events. There are two handlers:
    stderr (INFO, or DEBUG when *debug* is set) and a file handler (DEBUG)
    writing ``nightly.log`` in this run's log directory.

    Any exception escaping the ``with`` body is logged through that logger,
    so it reaches both the console and the log file. Its type name is then
    recorded and the process exits with status 1.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if debug else logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Use ``root_logger``, not the module-level ``logging`` functions.
        # Those target the unconfigured stdlib root logger, so the traceback
        # would never reach nightly.log.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # You could log at DEBUG here to suppress the backtrace
        # entirely, but there is no reason to hide it from technically
        # savvy users.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)


def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Checks that the branch name can be checked out.

    Only the "checkout" subcommand is validated. It requires that:

    * a branch name was given,
    * the working tree has no changes to tracked files, and
    * ``refs/heads/<branch>`` does not exist yet.

    Returns:
        An error message describing the first failed check, or None.
    """
    if subcommand != "checkout":
        return None
    # first make sure actual branch name was given
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    # next check that the local repo is clean
    cmd = git("status", "--untracked-files=no", "--porcelain")
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    if stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + stdout
    # next check that the branch name doesn't already exist
    cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}")
    p = subprocess.run(cmd, capture_output=True, check=False)  # type: ignore[assignment]
    # show-ref --verify exits 0 only when the ref exists.
    if not p.returncode:
        return f"Branch {branch!r} already exists"
    return None


@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Timed context manager.

    Logs ``"<prefix> took X.XXX [s]"`` at INFO level. The timing is only
    logged when the body completes without raising.
    """
    start_time = time.perf_counter()
    yield
    logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)


F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator for timing functions.

    Each call logs *prefix* through the module-level ``LOGGER`` and then runs
    the function inside ``timer``. ``LOGGER`` is read at call time, so it
    must hold a logger before a decorated function is invoked.
    """

    def dec(f: F) -> F:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            logger = cast(logging.Logger, LOGGER)
            logger.info(prefix)
            with timer(logger, prefix):
                return f(*args, **kwargs)

        return cast(F, wrapper)

    return dec


def _make_channel_args(
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> list[str]:
    """Build conda ``--channel`` arguments, plus ``--override-channels`` if requested."""
    args = []
    for channel in channels:
        args.extend(["--channel", channel])
    if override_channels:
        args.append("--override-channels")
    return args


@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Performs the conda solve and splits the deps from the package.

    The solve is a ``--dry-run --json`` conda invocation:

    * ``install`` into the environment given by *prefix* (checked first) or
      *name*;
    * otherwise ``create`` with a placeholder name, targeting a fresh
      "pytorch-deps" environment.

    Returns:
        ``(deps, pytorch, platform, existing_env, env_opts)``:

        * tarball URLs of every linked package except pytorch,
        * the pytorch tarball URL,
        * pytorch's conda platform string,
        * whether an existing env was named,
        * the ``--prefix``/``--name`` options selecting the target env.

    Raises:
        RuntimeError: if the solve contains no pytorch package.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]
    # run solve
    if existing_env:
        cmd = [
            "conda",
            "install",
            "--yes",
            "--dry-run",
            "--json",
        ]
        cmd.extend(env_opts)
    else:
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    channel_args = _make_channel_args(
        channels=channels,
        override_channels=override_channels,
    )
    cmd.extend(channel_args)
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    # parse solution
    solve = json.loads(stdout)
    link = solve["actions"]["LINK"]
    deps = []
    pytorch, platform = "", ""
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    if not pytorch:
        raise RuntimeError("PyTorch package not found in solve")
    if not platform:
        raise RuntimeError("Platform not found in solve")
    return deps, pytorch, platform, existing_env, env_opts


@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install dependencies to deps environment.

    A non-existing env is first removed with ``conda env remove`` and then
    re-created. The solved dependency URLs are passed explicitly with
    ``--no-deps``.
    """
    if not existing_env:
        # first remove previous pytorch-deps env
        cmd = ["conda", "env", "remove", "--yes", *env_opts]
        subprocess.check_call(cmd)
    # install new deps
    install_command = "install" if existing_env else "create"
    cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
    subprocess.check_call(cmd)


@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Install pytorch into a temporary directory.

    Only the package at *url* is installed (``--no-deps``) into a fresh temp
    prefix. The caller owns the returned ``TemporaryDirectory`` and decides
    when it is cleaned up.
    """
    pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
    subprocess.check_call(cmd)
    return pytorch_dir


def _site_packages(dirname: str, platform: str) -> Path:
    """Return the absolute ``site-packages`` dir of the environment at *dirname*.

    Windows envs use ``Lib/site-packages``; others use
    ``lib/python*.*/site-packages`` (the first glob match is taken).

    Raises:
        RuntimeError: if no matching directory exists.
    """
    if platform.startswith("win"):
        template = os.path.join(dirname, "Lib", "site-packages")
    else:
        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
    match = next(glob.iglob(template), None)
    if match is None:
        # A bare ``next()`` would raise an uninformative StopIteration here.
        raise RuntimeError(f"Could not find site-packages matching {template!r}")
    return Path(match).absolute()
Coastguard Workerdef _ensure_commit(git_sha1: str) -> None: 350*da0073e9SAndroid Build Coastguard Worker """Make sure that we actually have the commit locally""" 351*da0073e9SAndroid Build Coastguard Worker cmd = git("cat-file", "-e", git_sha1 + r"^{commit}") 352*da0073e9SAndroid Build Coastguard Worker p = subprocess.run(cmd, capture_output=True, check=False) 353*da0073e9SAndroid Build Coastguard Worker if p.returncode == 0: 354*da0073e9SAndroid Build Coastguard Worker # we have the commit locally 355*da0073e9SAndroid Build Coastguard Worker return 356*da0073e9SAndroid Build Coastguard Worker # we don't have the commit, must fetch 357*da0073e9SAndroid Build Coastguard Worker cmd = git("fetch", GITHUB_REMOTE_URL, git_sha1) 358*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(cmd) 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Workerdef _nightly_version(site_dir: Path) -> str: 362*da0073e9SAndroid Build Coastguard Worker # first get the git version from the installed module 363*da0073e9SAndroid Build Coastguard Worker version_file = site_dir / "torch" / "version.py" 364*da0073e9SAndroid Build Coastguard Worker with version_file.open(encoding="utf-8") as f: 365*da0073e9SAndroid Build Coastguard Worker for line in f: 366*da0073e9SAndroid Build Coastguard Worker if not line.startswith("git_version"): 367*da0073e9SAndroid Build Coastguard Worker continue 368*da0073e9SAndroid Build Coastguard Worker git_version = literal_eval(line.partition("=")[2].strip()) 369*da0073e9SAndroid Build Coastguard Worker break 370*da0073e9SAndroid Build Coastguard Worker else: 371*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Could not find git_version in {version_file}") 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker print(f"Found released git version {git_version}") 374*da0073e9SAndroid Build Coastguard Worker # now cross reference with 
nightly version 375*da0073e9SAndroid Build Coastguard Worker _ensure_commit(git_version) 376*da0073e9SAndroid Build Coastguard Worker cmd = git("show", "--no-patch", "--format=%s", git_version) 377*da0073e9SAndroid Build Coastguard Worker stdout = subprocess.check_output(cmd, text=True, encoding="utf-8") 378*da0073e9SAndroid Build Coastguard Worker m = SHA1_RE.search(stdout) 379*da0073e9SAndroid Build Coastguard Worker if m is None: 380*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 381*da0073e9SAndroid Build Coastguard Worker f"Could not find nightly release in git history:\n {stdout}" 382*da0073e9SAndroid Build Coastguard Worker ) 383*da0073e9SAndroid Build Coastguard Worker nightly_version = m.group("sha1") 384*da0073e9SAndroid Build Coastguard Worker print(f"Found nightly release version {nightly_version}") 385*da0073e9SAndroid Build Coastguard Worker # now checkout nightly version 386*da0073e9SAndroid Build Coastguard Worker _ensure_commit(nightly_version) 387*da0073e9SAndroid Build Coastguard Worker return nightly_version 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker@timed("Checking out nightly PyTorch") 391*da0073e9SAndroid Build Coastguard Workerdef checkout_nightly_version(branch: str, site_dir: Path) -> None: 392*da0073e9SAndroid Build Coastguard Worker """Get's the nightly version and then checks it out.""" 393*da0073e9SAndroid Build Coastguard Worker nightly_version = _nightly_version(site_dir) 394*da0073e9SAndroid Build Coastguard Worker cmd = git("checkout", "-b", branch, nightly_version) 395*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(cmd) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker@timed("Pulling nightly PyTorch") 399*da0073e9SAndroid Build Coastguard Workerdef pull_nightly_version(site_dir: Path) -> None: 400*da0073e9SAndroid Build 
Coastguard Worker """Fetches the nightly version and then merges it .""" 401*da0073e9SAndroid Build Coastguard Worker nightly_version = _nightly_version(site_dir) 402*da0073e9SAndroid Build Coastguard Worker cmd = git("merge", nightly_version) 403*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(cmd) 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Workerdef _get_listing_linux(source_dir: Path) -> list[Path]: 407*da0073e9SAndroid Build Coastguard Worker return list( 408*da0073e9SAndroid Build Coastguard Worker itertools.chain( 409*da0073e9SAndroid Build Coastguard Worker source_dir.glob("*.so"), 410*da0073e9SAndroid Build Coastguard Worker (source_dir / "lib").glob("*.so"), 411*da0073e9SAndroid Build Coastguard Worker (source_dir / "lib").glob("*.so.*"), 412*da0073e9SAndroid Build Coastguard Worker ) 413*da0073e9SAndroid Build Coastguard Worker ) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Workerdef _get_listing_osx(source_dir: Path) -> list[Path]: 417*da0073e9SAndroid Build Coastguard Worker # oddly, these are .so files even on Mac 418*da0073e9SAndroid Build Coastguard Worker return list( 419*da0073e9SAndroid Build Coastguard Worker itertools.chain( 420*da0073e9SAndroid Build Coastguard Worker source_dir.glob("*.so"), 421*da0073e9SAndroid Build Coastguard Worker (source_dir / "lib").glob("*.dylib"), 422*da0073e9SAndroid Build Coastguard Worker ) 423*da0073e9SAndroid Build Coastguard Worker ) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Workerdef _get_listing_win(source_dir: Path) -> list[Path]: 427*da0073e9SAndroid Build Coastguard Worker return list( 428*da0073e9SAndroid Build Coastguard Worker itertools.chain( 429*da0073e9SAndroid Build Coastguard Worker source_dir.glob("*.pyd"), 430*da0073e9SAndroid Build 
Coastguard Worker (source_dir / "lib").glob("*.lib"), 431*da0073e9SAndroid Build Coastguard Worker (source_dir / "lib").glob(".dll"), 432*da0073e9SAndroid Build Coastguard Worker ) 433*da0073e9SAndroid Build Coastguard Worker ) 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Workerdef _glob_pyis(d: Path) -> set[str]: 437*da0073e9SAndroid Build Coastguard Worker return {p.relative_to(d).as_posix() for p in d.rglob("*.pyi")} 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Workerdef _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]: 441*da0073e9SAndroid Build Coastguard Worker source_pyis = _glob_pyis(source_dir) 442*da0073e9SAndroid Build Coastguard Worker target_pyis = _glob_pyis(target_dir) 443*da0073e9SAndroid Build Coastguard Worker missing_pyis = sorted(source_dir / p for p in (source_pyis - target_pyis)) 444*da0073e9SAndroid Build Coastguard Worker return missing_pyis 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Workerdef _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]: 448*da0073e9SAndroid Build Coastguard Worker if platform.startswith("linux"): 449*da0073e9SAndroid Build Coastguard Worker listing = _get_listing_linux(source_dir) 450*da0073e9SAndroid Build Coastguard Worker elif platform.startswith("osx"): 451*da0073e9SAndroid Build Coastguard Worker listing = _get_listing_osx(source_dir) 452*da0073e9SAndroid Build Coastguard Worker elif platform.startswith("win"): 453*da0073e9SAndroid Build Coastguard Worker listing = _get_listing_win(source_dir) 454*da0073e9SAndroid Build Coastguard Worker else: 455*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Platform {platform!r} not recognized") 456*da0073e9SAndroid Build Coastguard Worker 
listing.extend(_find_missing_pyi(source_dir, target_dir)) 457*da0073e9SAndroid Build Coastguard Worker listing.append(source_dir / "version.py") 458*da0073e9SAndroid Build Coastguard Worker listing.append(source_dir / "testing" / "_internal" / "generated") 459*da0073e9SAndroid Build Coastguard Worker listing.append(source_dir / "bin") 460*da0073e9SAndroid Build Coastguard Worker listing.append(source_dir / "include") 461*da0073e9SAndroid Build Coastguard Worker return listing 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Workerdef _remove_existing(path: Path) -> None: 465*da0073e9SAndroid Build Coastguard Worker if path.exists(): 466*da0073e9SAndroid Build Coastguard Worker if path.is_dir(): 467*da0073e9SAndroid Build Coastguard Worker shutil.rmtree(path) 468*da0073e9SAndroid Build Coastguard Worker else: 469*da0073e9SAndroid Build Coastguard Worker path.unlink() 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Workerdef _move_single( 473*da0073e9SAndroid Build Coastguard Worker src: Path, 474*da0073e9SAndroid Build Coastguard Worker source_dir: Path, 475*da0073e9SAndroid Build Coastguard Worker target_dir: Path, 476*da0073e9SAndroid Build Coastguard Worker mover: Callable[[Path, Path], None], 477*da0073e9SAndroid Build Coastguard Worker verb: str, 478*da0073e9SAndroid Build Coastguard Worker) -> None: 479*da0073e9SAndroid Build Coastguard Worker relpath = src.relative_to(source_dir) 480*da0073e9SAndroid Build Coastguard Worker trg = target_dir / relpath 481*da0073e9SAndroid Build Coastguard Worker _remove_existing(trg) 482*da0073e9SAndroid Build Coastguard Worker # move over new files 483*da0073e9SAndroid Build Coastguard Worker if src.is_dir(): 484*da0073e9SAndroid Build Coastguard Worker trg.mkdir(parents=True, exist_ok=True) 485*da0073e9SAndroid Build Coastguard Worker for root, dirs, files in 
os.walk(src): 486*da0073e9SAndroid Build Coastguard Worker relroot = Path(root).relative_to(src) 487*da0073e9SAndroid Build Coastguard Worker for name in files: 488*da0073e9SAndroid Build Coastguard Worker relname = relroot / name 489*da0073e9SAndroid Build Coastguard Worker s = src / relname 490*da0073e9SAndroid Build Coastguard Worker t = trg / relname 491*da0073e9SAndroid Build Coastguard Worker print(f"{verb} {s} -> {t}") 492*da0073e9SAndroid Build Coastguard Worker mover(s, t) 493*da0073e9SAndroid Build Coastguard Worker for name in dirs: 494*da0073e9SAndroid Build Coastguard Worker (trg / relroot / name).mkdir(parents=True, exist_ok=True) 495*da0073e9SAndroid Build Coastguard Worker else: 496*da0073e9SAndroid Build Coastguard Worker print(f"{verb} {src} -> {trg}") 497*da0073e9SAndroid Build Coastguard Worker mover(src, trg) 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Workerdef _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None: 501*da0073e9SAndroid Build Coastguard Worker for src in listing: 502*da0073e9SAndroid Build Coastguard Worker _move_single(src, source_dir, target_dir, shutil.copy2, "Copying") 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Workerdef _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None: 506*da0073e9SAndroid Build Coastguard Worker for src in listing: 507*da0073e9SAndroid Build Coastguard Worker _move_single(src, source_dir, target_dir, os.link, "Linking") 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker@timed("Moving nightly files into repo") 511*da0073e9SAndroid Build Coastguard Workerdef move_nightly_files(site_dir: Path, platform: str) -> None: 512*da0073e9SAndroid Build Coastguard Worker """Moves PyTorch files from temporary installed 
location to repo.""" 513*da0073e9SAndroid Build Coastguard Worker # get file listing 514*da0073e9SAndroid Build Coastguard Worker source_dir = site_dir / "torch" 515*da0073e9SAndroid Build Coastguard Worker target_dir = REPO_ROOT / "torch" 516*da0073e9SAndroid Build Coastguard Worker listing = _get_listing(source_dir, target_dir, platform) 517*da0073e9SAndroid Build Coastguard Worker # copy / link files 518*da0073e9SAndroid Build Coastguard Worker if platform.startswith("win"): 519*da0073e9SAndroid Build Coastguard Worker _copy_files(listing, source_dir, target_dir) 520*da0073e9SAndroid Build Coastguard Worker else: 521*da0073e9SAndroid Build Coastguard Worker try: 522*da0073e9SAndroid Build Coastguard Worker _link_files(listing, source_dir, target_dir) 523*da0073e9SAndroid Build Coastguard Worker except Exception: 524*da0073e9SAndroid Build Coastguard Worker _copy_files(listing, source_dir, target_dir) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Workerdef _available_envs() -> dict[str, str]: 528*da0073e9SAndroid Build Coastguard Worker cmd = ["conda", "env", "list"] 529*da0073e9SAndroid Build Coastguard Worker stdout = subprocess.check_output(cmd, text=True, encoding="utf-8") 530*da0073e9SAndroid Build Coastguard Worker envs = {} 531*da0073e9SAndroid Build Coastguard Worker for line in map(str.strip, stdout.splitlines()): 532*da0073e9SAndroid Build Coastguard Worker if not line or line.startswith("#"): 533*da0073e9SAndroid Build Coastguard Worker continue 534*da0073e9SAndroid Build Coastguard Worker parts = line.split() 535*da0073e9SAndroid Build Coastguard Worker if len(parts) == 1: 536*da0073e9SAndroid Build Coastguard Worker # unnamed env 537*da0073e9SAndroid Build Coastguard Worker continue 538*da0073e9SAndroid Build Coastguard Worker envs[parts[0]] = parts[-1] 539*da0073e9SAndroid Build Coastguard Worker return envs 540*da0073e9SAndroid Build Coastguard Worker 
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Writes Python path file for this dir.

    Creates ``pytorch-nightly.pth`` in the environment's site-packages listing
    ``REPO_ROOT``, so the repo checkout is importable from that environment.

    Args:
        env_opts: two-element ``[option, value]`` list. With ``"--name"`` the
            value is an environment name resolved through ``conda env list``;
            otherwise the value is used as the environment directory.
        platform: platform string passed on to ``_site_packages``.

    Raises:
        KeyError: if a ``--name`` environment is not listed by conda.
    """
    env_type, env_dir = env_opts
    if env_type == "--name":
        # have to find directory
        envs = _available_envs()
        env_dir = envs[env_dir]
    site_dir = _site_packages(env_dir, platform)
    (site_dir / "pytorch-nightly.pth").write_text(
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{REPO_ROOT}\n",
        encoding="utf-8",
    )


def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch

    Solves *specs* with conda, installs the resulting dependencies, installs
    the PyTorch package to a temporary location, checks out or merges the
    nightly version, moves the nightly files into the repo and writes the
    ``.pth`` file.

    Args:
        specs: conda package specs to solve for.
        logger: receives the final "environment set up" message.
        subcommand: ``"checkout"`` (check out *branch* at the nightly version)
            or ``"pull"`` (merge the nightly version into the current branch).
        branch: branch name; only used by ``"checkout"``.
        name: conda environment name (mutually usable with *prefix*).
        prefix: full path to the conda environment.
        channels: conda channels to search.
        override_channels: do not search default or .condarc channels.

    Raises:
        ValueError: if *subcommand* is neither ``"checkout"`` nor ``"pull"``.
    """
    # Validate before the (slow) conda solve and dependency install, so a bad
    # subcommand does not leave a half-modified environment behind.
    if subcommand not in ("checkout", "pull"):
        raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
    specs = list(specs)
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        specs=specs,
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        else:  # "pull", guaranteed by the check above
            pull_nightly_version(site_dir)
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n  $ conda activate %s",
        env_opts[1],
    )


def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with the ``checkout`` and ``pull`` subcommands.

    Both subcommands share the environment, verbosity and channel options;
    ``--cuda`` is only offered on Linux and Windows.
    """
    p = argparse.ArgumentParser()
    # subcommands
    # required=True: without a subcommand the options defined on the
    # subparsers (channels, verbose, ...) would be missing from the namespace
    # and main() would crash with AttributeError instead of printing usage.
    subcmd = p.add_subparsers(
        dest="subcmd", help="subcommand to execute", required=True
    )
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            # SUPPRESS keeps "cuda" out of the namespace unless --cuda is given;
            # a bare --cuda yields None (i.e. "latest available").
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p


def main(args: Sequence[str] | None = None) -> None:
    """Main entry point

    Parses *args* (``sys.argv[1:]`` when None), exits with the message from
    ``check_branch`` if it reports a problem, builds the conda specs and
    channels (CUDA or CPU flavour) and runs ``install``.
    """
    global LOGGER
    p = make_parser()
    ns = p.parse_args(args)
    # only the "checkout" subparser defines --branch
    ns.branch = getattr(ns, "branch", None)
    status = check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)
    specs = list(SPECS_TO_INSTALL)
    channels = ["pytorch-nightly"]
    if hasattr(ns, "cuda"):
        if ns.cuda is not None:
            specs.append(f"pytorch-cuda={ns.cuda}")
        else:
            specs.append("pytorch-cuda")
        specs.append("pytorch-mutex=*=*cuda*")
        channels.append("nvidia")
    else:
        specs.append("pytorch-mutex=*=*cpu*")
    if ns.channels:
        channels.extend(ns.channels)
    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )


if __name__ == "__main__":
    main()