xref: /aosp_15_r20/external/pytorch/tools/nightly.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker# Much of the logging code here was forked from https://github.com/ezyang/ghstack
3*da0073e9SAndroid Build Coastguard Worker# Copyright (c) Edward Z. Yang <[email protected]>
4*da0073e9SAndroid Build Coastguard Worker"""Checks out the nightly development version of PyTorch and installs pre-built
5*da0073e9SAndroid Build Coastguard Workerbinaries into the repo.
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard WorkerYou can use this script to check out a new nightly branch with the following::
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker    $ ./tools/nightly.py checkout -b my-nightly-branch
10*da0073e9SAndroid Build Coastguard Worker    $ conda activate pytorch-deps
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard WorkerOr if you would like to re-use an existing conda environment, you can pass in
13*da0073e9SAndroid Build Coastguard Workerthe regular environment parameters (--name or --prefix)::
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker    $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
16*da0073e9SAndroid Build Coastguard Worker    $ conda activate my-env
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard WorkerTo install the nightly binaries built with CUDA, you can pass in the flag --cuda::
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker    $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
21*da0073e9SAndroid Build Coastguard Worker    $ conda activate pytorch-deps
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard WorkerYou can also use this tool to pull the nightly commits into the current branch as
24*da0073e9SAndroid Build Coastguard Workerwell. This can be done with::
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker    $ ./tools/nightly.py pull -n my-env
27*da0073e9SAndroid Build Coastguard Worker    $ conda activate my-env
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard WorkerPulling will reinstall the conda dependencies as well as the nightly binaries into
30*da0073e9SAndroid Build Coastguard Workerthe repo directory.
31*da0073e9SAndroid Build Coastguard Worker"""
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Workerimport argparse
36*da0073e9SAndroid Build Coastguard Workerimport contextlib
37*da0073e9SAndroid Build Coastguard Workerimport functools
38*da0073e9SAndroid Build Coastguard Workerimport glob
39*da0073e9SAndroid Build Coastguard Workerimport itertools
40*da0073e9SAndroid Build Coastguard Workerimport json
41*da0073e9SAndroid Build Coastguard Workerimport logging
42*da0073e9SAndroid Build Coastguard Workerimport os
43*da0073e9SAndroid Build Coastguard Workerimport re
44*da0073e9SAndroid Build Coastguard Workerimport shutil
45*da0073e9SAndroid Build Coastguard Workerimport subprocess
46*da0073e9SAndroid Build Coastguard Workerimport sys
47*da0073e9SAndroid Build Coastguard Workerimport tempfile
48*da0073e9SAndroid Build Coastguard Workerimport time
49*da0073e9SAndroid Build Coastguard Workerimport uuid
50*da0073e9SAndroid Build Coastguard Workerfrom ast import literal_eval
51*da0073e9SAndroid Build Coastguard Workerfrom datetime import datetime
52*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
53*da0073e9SAndroid Build Coastguard Workerfrom platform import system as platform_system
54*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
# Root of the PyTorch checkout; this script lives in <repo>/tools/.
REPO_ROOT = Path(__file__).absolute().parent.parent
# Upstream repository used when a needed commit is not available locally.
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
# Conda package specs to install.
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

# Download URL of a conda package, filled from a package record in conda's
# JSON solve output (keys: base_url, platform, dist_name).
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# strftime format of the timestamp that prefixes each run's log directory;
# zero-padded, so lexicographic order equals chronological order.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"

# A full 40-hex-digit git commit hash.
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
# Credentials embedded in a URL ("://user:password@"), for log redaction.
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
# Name of a run log directory: "<DATETIME_FORMAT timestamp>_<uuid>".
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)

# Module-wide logger; stays None until the caller assigns it.
LOGGER: logging.Logger | None = None
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
class Formatter(logging.Formatter):
    """Log formatter that prints INFO/DEBUG messages bare and scrubs secrets.

    Records at other levels are rendered with the configured ``fmt``. Every
    formatted string then passes through :meth:`_filter`, which masks
    ``user:password@`` credentials in URLs and applies any literal
    substitutions registered with :meth:`redact`.
    """

    # Literal substring -> text that replaces it in formatted output.
    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    def _filter(self, s: str) -> str:
        """Return *s* with URL credentials and registered needles masked."""
        scrubbed = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for secret, placeholder in self.redactions.items():
            scrubbed = scrubbed.replace(secret, placeholder)
        return scrubbed

    def formatMessage(self, record: logging.LogRecord) -> str:
        """Render INFO/DEBUG records as the plain message, others via ``fmt``."""
        if record.levelno in (logging.INFO, logging.DEBUG):
            return record.getMessage()
        return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        """Format *record* as usual, then scrub sensitive text from the result."""
        rendered = super().format(record)
        return self._filter(rendered)

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Register *needle* to be replaced by *replace* in future output.

        Redaction is not retroactive: anything already logged stays as it
        was, so register secrets as early as possible.
        """
        # An empty needle would match between every character and produce
        # output like s<REDACTED>t<REDACTED>r..., so it is ignored.
        if needle != "":
            self.redactions[needle] = replace
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
def git(*args: str) -> list[str]:
    """Build an argv list that runs ``git`` against the repo at REPO_ROOT."""
    base = ["git", "-C", str(REPO_ROOT)]
    return base + list(args)
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache
def logging_base_dir() -> Path:
    """Return ``<repo>/nightly/log``, creating it on the first call.

    The result is cached, so the directory is only created once per process.
    """
    log_root = REPO_ROOT.joinpath("nightly", "log")
    log_root.mkdir(parents=True, exist_ok=True)
    return log_root
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache
def logging_run_dir() -> Path:
    """Return (creating it on first call) this process's own log directory.

    The directory name is ``<DATETIME_FORMAT timestamp>_<uuid1>``. Because of
    the cache, every call in the same process returns the same directory.
    """
    parent = logging_base_dir()
    stamp = datetime.now().strftime(DATETIME_FORMAT)
    run_dir = parent.joinpath(f"{stamp}_{uuid.uuid1()}")
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache
def logging_record_argv() -> None:
    """Save the quoted command line to ``argv`` in the run log directory.

    The cache turns every call after the first into a no-op.
    """
    command_line = subprocess.list2cmdline(sys.argv)
    argv_file = logging_run_dir().joinpath("argv")
    argv_file.write_text(command_line, encoding="utf-8")
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker
def logging_record_exception(e: BaseException) -> None:
    """Write the class name of *e* to ``exception`` in the run log directory."""
    exception_name = type(e).__name__
    target = logging_run_dir().joinpath("exception")
    target.write_text(exception_name, encoding="utf-8")
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker
def logging_rotate() -> None:
    """Delete all but the 1000 most recent run log directories.

    Only directories whose names match LOG_DIRNAME_RE count as run logs.
    Anything else in the log folder is neither deleted nor counted toward the
    retention limit. Run directory names start with a zero-padded
    DATETIME_FORMAT timestamp, so sorting names in reverse puts the newest
    first.
    """
    log_base = logging_base_dir()
    # Filter before slicing, so stray entries cannot take up retention slots.
    # A matching plain file would also make shutil.rmtree fail, so require
    # directories.
    run_dirs = sorted(
        (
            entry
            for entry in log_base.iterdir()
            if entry.is_dir() and LOG_DIRNAME_RE.fullmatch(entry.name) is not None
        ),
        reverse=True,
    )
    for stale_log in run_dirs[1000:]:
        shutil.rmtree(stale_log)
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Set up logging. If a failure happens here, we won't
    be able to save the user in a reasonable way.

    Logging structure: there is one logger ("conda-pytorch"), and it
    processes all events. There are two handlers:
    stderr (INFO, or DEBUG when *debug* is true) and a file handler (DEBUG)
    writing ``nightly.log`` in this run's log directory.

    Any exception escaping the ``with`` body is logged through both
    handlers, its type name is recorded in the run directory, the log file
    path is printed, and the process exits with status 1.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    if debug:
        console_handler.setLevel(logging.DEBUG)
    else:
        console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    # Explicit UTF-8 so non-ASCII log text does not depend on the locale
    # (consistent with the other log-directory writes in this file).
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Log through our own logger. The module-level logging.* functions
        # target the *root* logger, which has none of the handlers above, so
        # the traceback would never reach nightly.log.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # You could log at DEBUG here to suppress the backtrace
        # entirely, but there is no reason to hide it from technically
        # savvy users.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker
def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Check whether *branch* can be created by the ``checkout`` subcommand.

    Returns an error message explaining why the checkout cannot proceed, or
    ``None`` if it can (always ``None`` for subcommands other than
    ``checkout``).
    """
    if subcommand == "checkout":
        if branch is None:
            return "Branch name to checkout must be supplied with '-b' option"
        # Tracked files with uncommitted changes block the checkout.
        status = subprocess.check_output(
            git("status", "--untracked-files=no", "--porcelain"),
            text=True,
            encoding="utf-8",
        )
        if status.strip():
            return "Need to have clean working tree to checkout!\n\n" + status
        # `show-ref --verify` exits with status 0 only when the ref exists.
        probe = subprocess.run(
            git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}"),
            capture_output=True,
            check=False,
        )
        if probe.returncode == 0:
            return f"Branch {branch!r} already exists"
    return None
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Log the wall-clock duration of the ``with`` body at INFO level.

    The message has the form ``"<prefix> took N.NNN [s]"``. Nothing is
    logged if the body raises, because the log call is not in a ``finally``.
    """
    began = time.perf_counter()
    yield
    elapsed = time.perf_counter() - began
    logger.info("%s took %.3f [s]", prefix, elapsed)
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker
F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator factory that logs *prefix*, runs the function, and logs its duration.

    The module-level LOGGER is looked up on every call rather than at
    decoration time, so it only needs to be set before the first call.
    """

    def decorate(func: F) -> F:
        @functools.wraps(func)
        def timed_call(*args: Any, **kwargs: Any) -> Any:
            active_logger = cast(logging.Logger, LOGGER)
            active_logger.info(prefix)
            with timer(active_logger, prefix):
                result = func(*args, **kwargs)
            return result

        return cast(F, timed_call)

    return decorate
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Workerdef _make_channel_args(
243*da0073e9SAndroid Build Coastguard Worker    channels: Iterable[str] = ("pytorch-nightly",),
244*da0073e9SAndroid Build Coastguard Worker    override_channels: bool = False,
245*da0073e9SAndroid Build Coastguard Worker) -> list[str]:
246*da0073e9SAndroid Build Coastguard Worker    args = []
247*da0073e9SAndroid Build Coastguard Worker    for channel in channels:
248*da0073e9SAndroid Build Coastguard Worker        args.extend(["--channel", channel])
249*da0073e9SAndroid Build Coastguard Worker    if override_channels:
250*da0073e9SAndroid Build Coastguard Worker        args.append("--override-channels")
251*da0073e9SAndroid Build Coastguard Worker    return args
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker
@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Run a dry-run conda solve and split the deps from the pytorch package.

    Args:
        specs: Conda package specs to solve for.
        name: Name of an existing environment to solve against.
        prefix: Path of an existing environment to solve against; takes
            precedence over *name*.
        channels: Channels passed to conda with ``--channel``.
        override_channels: If true, also pass ``--override-channels``.

    Returns:
        A tuple ``(deps, pytorch, platform, existing_env, env_opts)``:
        - the download URLs of every linked package except ``pytorch``;
        - the download URL of the ``pytorch`` package;
        - that package's conda platform;
        - whether an existing environment was targeted;
        - the ``--prefix``/``--name`` options selecting the environment to
          install into (``--name pytorch-deps`` if none was given).

    Raises:
        subprocess.CalledProcessError: If the conda command fails.
        RuntimeError: If the solve has no ``pytorch`` package or no platform
            for it.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]
    # run solve (--dry-run: conda only reports its plan, nothing is installed)
    if existing_env:
        cmd = ["conda", "install", "--yes", "--dry-run", "--json", *env_opts]
    else:
        # NOTE(review): the dry-run create uses a placeholder env name,
        # presumably so it does not collide with an existing "pytorch-deps"
        # environment; confirm.
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    cmd.extend(_make_channel_args(channels=channels, override_channels=override_channels))
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    # parse solution: actions.LINK lists the package records conda would link
    solve = json.loads(stdout)
    link = solve["actions"]["LINK"]
    deps: list[str] = []
    pytorch, platform = "", ""
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if not pytorch:
        raise RuntimeError("PyTorch package not found in solve")
    if not platform:
        raise RuntimeError("Platform not found in solve")
    return deps, pytorch, platform, existing_env, env_opts
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker
@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install the package URLs in *deps* into the environment named by *env_opts*.

    An existing environment is updated with ``conda install``. Otherwise the
    environment is removed and recreated with ``conda create``. ``--no-deps``
    is always passed, so *deps* is expected to be a complete, already-solved
    set of packages.
    """
    if existing_env:
        install_command = "install"
    else:
        # Remove the previously created environment before recreating it.
        subprocess.check_call(["conda", "env", "remove", "--yes", *env_opts])
        install_command = "create"
    subprocess.check_call(
        ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps]
    )
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Install the pytorch package at *url* into a new temporary conda prefix.

    Returns:
        The TemporaryDirectory holding the prefix. The caller owns it and is
        responsible for cleaning it up.

    Raises:
        subprocess.CalledProcessError: If ``conda create`` fails. In that
            case the temporary directory is removed before the error
            propagates.
    """
    pytorch_dir = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    cmd = ["conda", "create", "--yes", "--no-deps", f"--prefix={pytorch_dir.name}", url]
    try:
        subprocess.check_call(cmd)
    except BaseException:
        # Do not leave a half-populated prefix behind for the GC finalizer.
        pytorch_dir.cleanup()
        raise
    return pytorch_dir
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Workerdef _site_packages(dirname: str, platform: str) -> Path:
342*da0073e9SAndroid Build Coastguard Worker    if platform.startswith("win"):
343*da0073e9SAndroid Build Coastguard Worker        template = os.path.join(dirname, "Lib", "site-packages")
344*da0073e9SAndroid Build Coastguard Worker    else:
345*da0073e9SAndroid Build Coastguard Worker        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
346*da0073e9SAndroid Build Coastguard Worker    return Path(next(glob.iglob(template))).absolute()
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker
def _ensure_commit(git_sha1: str) -> None:
    """Make *git_sha1* available locally, fetching it from GitHub if needed."""
    # `cat-file -e <sha>^{commit}` succeeds only if the commit object exists.
    probe = subprocess.run(
        git("cat-file", "-e", f"{git_sha1}^{{commit}}"),
        capture_output=True,
        check=False,
    )
    if probe.returncode != 0:
        # Not available locally: fetch exactly that commit from upstream.
        subprocess.check_call(git("fetch", GITHUB_REMOTE_URL, git_sha1))
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker
def _nightly_version(site_dir: Path) -> str:
    """Find the nightly-branch commit that corresponds to the installed torch.

    Reads ``git_version`` from ``torch/version.py`` under *site_dir* and makes
    sure that commit is available locally. It then takes the first 40-hex
    SHA1 in that commit's subject line as the nightly commit, makes sure that
    one is available too, and returns it.

    Raises:
        RuntimeError: If ``git_version`` is missing from version.py, or the
            commit subject contains no SHA1.
    """
    version_file = site_dir / "torch" / "version.py"
    with version_file.open(encoding="utf-8") as f:
        assignment = next(
            (line for line in f if line.startswith("git_version")), None
        )
    if assignment is None:
        raise RuntimeError(f"Could not find git_version in {version_file}")
    # The right-hand side of `git_version = '<sha>'` is a Python literal.
    git_version = literal_eval(assignment.partition("=")[2].strip())

    print(f"Found released git version {git_version}")
    # Cross-reference with the nightly branch via the commit's subject line.
    _ensure_commit(git_version)
    subject = subprocess.check_output(
        git("show", "--no-patch", "--format=%s", git_version),
        text=True,
        encoding="utf-8",
    )
    found = SHA1_RE.search(subject)
    if found is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n  {subject}"
        )
    nightly_version = found.group("sha1")
    print(f"Found nightly release version {nightly_version}")
    # The caller checks out or merges this commit, so it must exist locally.
    _ensure_commit(nightly_version)
    return nightly_version
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, site_dir: Path) -> None:
    """Create and switch to *branch* at the nightly commit matching *site_dir*'s torch."""
    target_commit = _nightly_version(site_dir)
    subprocess.check_call(git("checkout", "-b", branch, target_commit))
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker
@timed("Pulling nightly PyTorch")
def pull_nightly_version(site_dir: Path) -> None:
    """Merge the nightly commit matching *site_dir*'s torch into the current branch."""
    target_commit = _nightly_version(site_dir)
    subprocess.check_call(git("merge", target_commit))
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Workerdef _get_listing_linux(source_dir: Path) -> list[Path]:
407*da0073e9SAndroid Build Coastguard Worker    return list(
408*da0073e9SAndroid Build Coastguard Worker        itertools.chain(
409*da0073e9SAndroid Build Coastguard Worker            source_dir.glob("*.so"),
410*da0073e9SAndroid Build Coastguard Worker            (source_dir / "lib").glob("*.so"),
411*da0073e9SAndroid Build Coastguard Worker            (source_dir / "lib").glob("*.so.*"),
412*da0073e9SAndroid Build Coastguard Worker        )
413*da0073e9SAndroid Build Coastguard Worker    )
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker
def _get_listing_osx(source_dir: Path) -> list[Path]:
    """Return the compiled files to move for a macOS nightly.

    The top-level files in *source_dir* still carry the ``.so`` suffix on
    macOS, while the shared libraries under ``lib/`` are ``.dylib`` files.
    """
    top_level = list(source_dir.glob("*.so"))
    dylibs = list((source_dir / "lib").glob("*.dylib"))
    return top_level + dylibs
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker
def _get_listing_win(source_dir: Path) -> list[Path]:
    """Return the compiled files to move for a Windows nightly.

    That is every top-level ``*.pyd`` in *source_dir*, followed by the
    ``*.lib`` and ``*.dll`` files found directly under ``lib/``.
    """
    lib_dir = source_dir / "lib"
    return list(
        itertools.chain(
            source_dir.glob("*.pyd"),
            lib_dir.glob("*.lib"),
            # The wildcard is required: a bare ".dll" pattern only matches a
            # file literally named ".dll", so no DLLs would be picked up.
            lib_dir.glob("*.dll"),
        )
    )
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker
def _glob_pyis(d: Path) -> set[str]:
    """Collect every ``.pyi`` stub below *d* (recursively).

    Each stub is returned as a POSIX-style path string relative to *d*, so
    results from two different roots can be compared directly.
    """
    stubs: set[str] = set()
    for stub in d.rglob("*.pyi"):
        stubs.add(stub.relative_to(d).as_posix())
    return stubs
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker
def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
    """List the ``.pyi`` stubs present in *source_dir* but not in *target_dir*.

    Stubs are matched by their path relative to each root. The result holds
    the corresponding paths under *source_dir*, in sorted order.
    """
    missing = _glob_pyis(source_dir).difference(_glob_pyis(target_dir))
    return sorted(source_dir / rel for rel in missing)
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker
def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
    """Build the full list of paths to move from *source_dir* into the repo.

    The platform-specific compiled files are chosen by the prefix of
    *platform* ("linux", "osx" or "win"). They are followed by:
    - the ``.pyi`` stubs missing from *target_dir*;
    - ``version.py``;
    - the generated testing sources;
    - the ``bin`` and ``include`` directories.

    Raises:
        RuntimeError: if *platform* starts with none of the known prefixes.
    """
    listers = (
        ("linux", _get_listing_linux),
        ("osx", _get_listing_osx),
        ("win", _get_listing_win),
    )
    for prefix, lister in listers:
        if platform.startswith(prefix):
            listing = lister(source_dir)
            break
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")

    listing.extend(_find_missing_pyi(source_dir, target_dir))
    listing.extend(
        [
            source_dir / "version.py",
            source_dir / "testing" / "_internal" / "generated",
            source_dir / "bin",
            source_dir / "include",
        ]
    )
    return listing
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker
def _remove_existing(path: Path) -> None:
    """Delete whatever currently lives at *path*; do nothing if it is absent.

    Directories are removed recursively. Symlinks are unlinked themselves,
    never followed. This matters in two cases:
    - ``shutil.rmtree`` refuses to operate on a symlink to a directory;
    - a dangling symlink reports ``exists() == False`` but would still
      block a new file being created at *path*.
    """
    if path.is_symlink():
        path.unlink()
    elif path.is_dir():
        shutil.rmtree(path)
    elif path.exists():
        path.unlink()
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker
def _move_single(
    src: Path,
    source_dir: Path,
    target_dir: Path,
    mover: Callable[[Path, Path], None],
    verb: str,
) -> None:
    """Mirror one listing entry from *source_dir* into *target_dir*.

    Whatever currently occupies the destination is removed first. A file is
    transferred with ``mover(src, dst)``. A directory is recreated, and every
    file inside it is transferred individually. Each transfer is printed as
    ``"<verb> <src> -> <dst>"``.

    Args:
        src: file or directory inside *source_dir* to transfer.
        source_dir: root that *src* is relative to.
        target_dir: root that receives the same relative path.
        mover: per-file transfer callable (callers here pass
            ``shutil.copy2`` or ``os.link``).
        verb: word used in the progress message.
    """
    relpath = src.relative_to(source_dir)
    trg = target_dir / relpath
    _remove_existing(trg)
    # move over new files
    if src.is_dir():
        trg.mkdir(parents=True, exist_ok=True)
        for root, dirs, files in os.walk(src):
            relroot = Path(root).relative_to(src)
            for name in files:
                relname = relroot / name
                s = src / relname
                t = trg / relname
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            # os.walk is top-down, so creating the subdirectories here means
            # they exist before their own files are visited.
            for name in dirs:
                (trg / relroot / name).mkdir(parents=True, exist_ok=True)
    else:
        # The destination's parent may not exist in the repo yet (e.g. a
        # stub under a directory only the nightly has); without this the
        # mover raises FileNotFoundError.
        trg.parent.mkdir(parents=True, exist_ok=True)
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Copy each entry of *listing* from *source_dir* into *target_dir*.

    Every file is transferred with ``shutil.copy2``, which copies contents
    together with file metadata.
    """
    copier = shutil.copy2
    for entry in listing:
        _move_single(entry, source_dir, target_dir, mover=copier, verb="Copying")
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker
def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Hard-link each entry of *listing* from *source_dir* into *target_dir*.

    Every file is transferred with ``os.link``. Any ``OSError`` it raises
    (e.g. when the two roots are on different filesystems) propagates to
    the caller.
    """
    linker = os.link
    for entry in listing:
        _move_single(entry, source_dir, target_dir, mover=linker, verb="Linking")
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker
@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
    """Moves PyTorch files from temporary installed location to repo.

    Files are taken from ``<site_dir>/torch`` and placed under
    ``REPO_ROOT/torch``. The set of files comes from ``_get_listing``.

    - On Windows the files are always copied.
    - Elsewhere hard links are tried first. If linking raises ``OSError``,
      the whole listing is copied instead; each copy replaces whatever the
      partial linking pass had already put in place.
    """
    # get file listing
    source_dir = site_dir / "torch"
    target_dir = REPO_ROOT / "torch"
    listing = _get_listing(source_dir, target_dir, platform)
    # copy / link files
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
        return
    try:
        _link_files(listing, source_dir, target_dir)
    except OSError as exc:
        # Hard links fail with OSError, e.g. across filesystems or where
        # unsupported. Only that is a reason to fall back; any other
        # exception is a real bug and should surface instead of being
        # masked by a copy.
        print(f"Hard-linking failed ({exc}); falling back to copying")
        _copy_files(listing, source_dir, target_dir)
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker
def _available_envs() -> dict[str, str]:
    """Map each named conda environment to its prefix directory.

    Parses the text output of ``conda env list``. Blank lines and ``#``
    comment lines are ignored. Data lines have the form
    ``name [*] /path/to/env``, where ``*`` marks the active environment.
    Lines that list only a path (unnamed environments) are skipped.
    """
    cmd = ["conda", "env", "list"]
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    envs = {}
    for line in map(str.strip, stdout.splitlines()):
        if not line or line.startswith("#"):
            continue
        # Split off only the name. conda rejects environment names that
        # contain spaces, but the prefix path may contain them (common for
        # user directories on Windows). Taking just the last token would
        # truncate such a path.
        parts = line.split(None, 1)
        if len(parts) == 1 or os.path.isabs(parts[0]):
            # unnamed env: the line holds only its (possibly spaced) path
            continue
        name, location = parts
        if location.startswith("*"):
            # drop the active-environment marker
            location = location[1:].strip()
        envs[name] = location
    return envs
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Write ``pytorch-nightly.pth`` into the target env's site-packages.

    *env_opts* is a ``[flag, value]`` pair:
    - for ``--name``, the value is an environment name, and its prefix is
      looked up via ``_available_envs()``;
    - otherwise the value is used directly as the prefix.

    The file is written to the directory returned by ``_site_packages``. It
    holds a comment header plus one line with ``REPO_ROOT``, so that
    directory is added to ``sys.path`` of that environment.
    """
    flag, location = env_opts
    if flag == "--name":
        # resolve a named environment to its directory
        location = _available_envs()[location]
    pth_file = _site_packages(location, platform) / "pytorch-nightly.pth"
    header = (
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
    )
    pth_file.write_text(header + f"{REPO_ROOT}\n", encoding="utf-8")
557*da0073e9SAndroid Build Coastguard Worker    )
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker
def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch.

    Steps:
    1. Solve the conda environment for *specs* and install the resolved
       dependencies, if there are any.
    2. Inside the PyTorch install provided by ``pytorch_install``, either
       check out (``"checkout"``) or merge (``"pull"``) the nightly commit.
    3. Move the nightly's binaries into the repo.
    4. Write the ``.pth`` file and log how to activate the environment.

    Raises:
        ValueError: if *subcommand* is neither ``"checkout"`` nor ``"pull"``.
            The check happens only after the dependencies have been
            installed.
    """
    solved = conda_solve(
        specs=list(specs),
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    deps, pytorch, platform, existing_env, env_opts = solved
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "pull":
            pull_nightly_version(site_dir)
        elif subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        else:
            raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n  $ conda activate %s",
        env_opts[1],
    )
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker
def make_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``checkout`` and ``pull`` subcommands.

    Options shared by both subcommands:
    - environment selection: ``-n/--name``, ``-p/--prefix``;
    - ``-v/--verbose``;
    - channel control: ``--override-channels``, ``-c/--channel``.

    ``checkout`` additionally accepts ``-b/--branch``. ``--cuda`` is only
    offered when ``platform_system()`` is Linux or Windows. A subcommand is
    mandatory.
    """
    p = argparse.ArgumentParser()
    # Subcommands. required=True makes a bare invocation a usage error.
    # Without it, argparse returns a namespace lacking every per-subcommand
    # option, and main() would then read attributes that were never set.
    subcmd = p.add_subparsers(
        dest="subcmd", help="subcommand to execute", required=True
    )
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        # May be repeated; values accumulate in a list (None when absent).
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            # SUPPRESS leaves ``cuda`` off the namespace unless --cuda is
            # given. A bare --cuda (nargs="?") stores None, meaning "no
            # specific version".
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker
def main(args: Sequence[str] | None = None) -> None:
    """Main entry point: parse *args* (``sys.argv`` when None) and install.

    Steps:
    1. Validate the requested branch via ``check_branch``; exit with its
       status when that is truthy.
    2. Assemble the conda specs and channels. With ``--cuda``, add a
       ``pytorch-cuda`` spec (optionally pinned), a cuda mutex and the
       ``nvidia`` channel. Without it, add a cpu mutex.
    3. Run ``install`` under the logging manager, publishing the logger in
       the module-level ``LOGGER``.
    """
    global LOGGER
    parser = make_parser()
    ns = parser.parse_args(args)
    # only the checkout subparser defines --branch
    ns.branch = getattr(ns, "branch", None)
    status = check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)

    channels = ["pytorch-nightly"]
    specs = list(SPECS_TO_INSTALL)
    if not hasattr(ns, "cuda"):
        specs.append("pytorch-mutex=*=*cpu*")
    else:
        cuda_spec = "pytorch-cuda" if ns.cuda is None else f"pytorch-cuda={ns.cuda}"
        specs += [cuda_spec, "pytorch-mutex=*=*cuda*"]
        channels.append("nvidia")
    channels.extend(ns.channels or ())

    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )


if __name__ == "__main__":
    # run the CLI only when executed as a script, not on import
    main()
715