xref: /aosp_15_r20/external/toolchain-utils/llvm_tools/get_patch.py (revision 760c253c1ed00ce9abd48f8546f08516e57485fe)
1#!/usr/bin/env python3
2# Copyright 2024 The ChromiumOS Authors
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""Get patches from a patch source, and integrate them into ChromiumOS.
7
8Example Usage:
9    # Apply a Pull request.
10    $ get_patch.py -s HEAD p:74791
11    # Apply several patches.
12    $ get_patch.py -s 82e851a407c5 p:74791 47413bb27
13    # Use another llvm-project dir.
14    $ get_patch.py -s HEAD -l ~/llvm-project 47413bb27
15"""
16
17import argparse
18import dataclasses
19import json
20import logging
21from pathlib import Path
22import random
23import re
24import subprocess
25import tempfile
26import textwrap
27from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
28from urllib import request
29
30import atomic_write_file
31import git_llvm_rev
32import patch_utils
33
34
# All paths below are relative to the root of a chromiumos source checkout.
CHROMIUMOS_OVERLAY_PATH = Path("src", "third_party", "chromiumos-overlay")
# Ebuild package directories (inside the overlay) that can receive patches.
LLVM_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "sys-devel" / "llvm"
COMPILER_RT_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "sys-libs" / "compiler-rt"
LIBCXX_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "sys-libs" / "libcxx"
LIBUNWIND_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "sys-libs" / "llvm-libunwind"
SCUDO_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "sys-libs" / "scudo"
LLDB_PKG_PATH = CHROMIUMOS_OVERLAY_PATH / "dev-util" / "lldb-server"

LLVM_PROJECT_PATH = Path("src", "third_party", "llvm-project")
# Name of the per-package JSON file that lists the package's patches.
PATCH_METADATA_FILENAME = "PATCHES.json"
45
46
class CherrypickError(ValueError):
    """Raised when a patch is already listed in a package's PATCHES.json."""
49
50
class CherrypickVersionError(ValueError):
    """Raised when a cherry-pick's ref is not strictly after the start_ref."""
53
54
@dataclasses.dataclass
class LLVMGitRef:
    """Represents an LLVM git ref, e.g. a commit SHA or 'HEAD'."""

    git_ref: str
    # Cache for to_rev(). Excluded from __eq__ so that two refs with the same
    # git_ref compare equal regardless of whether either has been resolved
    # yet (to_rev() mutates this field).
    _rev: Optional[git_llvm_rev.Rev] = dataclasses.field(
        default=None, compare=False
    )

    @classmethod
    def from_rev(cls, llvm_dir: Path, rev: git_llvm_rev.Rev) -> "LLVMGitRef":
        """Create a ref for an LLVM revision, pre-populating the rev cache.

        Args:
            llvm_dir: Path to the llvm-project checkout used for the
                revision-to-SHA translation.
            rev: The LLVM revision to translate.

        Returns:
            An LLVMGitRef whose git_ref is the translated SHA.
        """
        return cls(
            git_llvm_rev.translate_rev_to_sha(
                git_llvm_rev.LLVMConfig("origin", llvm_dir), rev
            ),
            _rev=rev,
        )

    def to_rev(self, llvm_dir: Path) -> git_llvm_rev.Rev:
        """Return the LLVM revision for this ref, translating at most once.

        Note: once cached, the stored revision is returned for any llvm_dir.
        """
        if self._rev is None:
            self._rev = git_llvm_rev.translate_sha_to_rev(
                git_llvm_rev.LLVMConfig("origin", llvm_dir),
                self.git_ref,
            )
        return self._rev
79
80
@dataclasses.dataclass(eq=True, frozen=True)
class LLVMPullRequest:
    """An upstream llvm-project GitHub pull request, identified by number."""

    # The PR number, e.g. 74791 for the command-line argument 'p:74791'.
    number: int
86
87
@dataclasses.dataclass
class PatchContext:
    """Represents the state of the chromiumos source during patching.

    Attributes:
        llvm_project_dir: Path to the llvm-project git checkout.
        chromiumos_root: Root of the chromiumos source tree; package paths
            are joined onto it.
        start_ref: Ref whose LLVM revision number becomes the "from"
            version of every new patch entry.
        platforms: Platforms recorded on every new patch entry.
        dry_run: If True, compute everything but write no files.
    """

    llvm_project_dir: Path
    chromiumos_root: Path
    start_ref: LLVMGitRef
    platforms: Iterable[str]
    dry_run: bool = False

    def __post_init__(self) -> None:
        # Each new PatchEntry receives list(self.platforms). A one-shot
        # iterable (e.g. a generator) would be exhausted by the first entry,
        # leaving later entries with no platforms, so snapshot it once here.
        self.platforms = list(self.platforms)

    def apply_patches(
        self, patch_source: Union[LLVMGitRef, LLVMPullRequest]
    ) -> None:
        """Create .patch files and add them to PATCHES.json.

        Post:
            Unless self.dry_run is True, writes the patch contents to
            the respective <pkg>/files/ workdir for each applicable
            patch, and the JSON files are updated with the new entries.

        Raises:
            TypeError: If the patch_source is not a
                LLVMGitRef or LLVMPullRequest.
            CherrypickError: If an affected PATCHES.json already lists the
                patch. No files are written in that case.
            CherrypickVersionError: If a LLVMGitRef patch_source is not
                strictly after start_ref.
        """
        new_patch_entries = self.make_patches(patch_source)
        self.apply_entries_to_json(new_patch_entries)

    def apply_entries_to_json(
        self,
        new_patch_entries: Iterable[patch_utils.PatchEntry],
    ) -> None:
        """Add some PatchEntries to the appropriate PATCHES.json.

        Entries are grouped by workdir and appended after the existing
        entries of <workdir>/PATCHES.json, using the indentation predicted
        from the file's current contents. Existing files are always read,
        but nothing is written when self.dry_run is True.
        """
        workdir_mappings: Dict[Path, List[patch_utils.PatchEntry]] = {}
        for pe in new_patch_entries:
            workdir_mappings.setdefault(pe.workdir, []).append(pe)
        for workdir, pes in workdir_mappings.items():
            patches_json_file = workdir / PATCH_METADATA_FILENAME
            with patches_json_file.open(encoding="utf-8") as f:
                orig_contents = f.read()
            old_patch_entries = patch_utils.json_str_to_patch_entries(
                workdir, orig_contents
            )
            indent_len = patch_utils.predict_indent(orig_contents.splitlines())
            if self.dry_run:
                continue
            with atomic_write_file.atomic_write(
                patches_json_file, encoding="utf-8"
            ) as f:
                json.dump(
                    [pe.to_dict() for pe in old_patch_entries + pes],
                    f,
                    indent=indent_len,
                )
                f.write("\n")

    def make_patches(
        self, patch_source: Union[LLVMGitRef, LLVMPullRequest]
    ) -> List[patch_utils.PatchEntry]:
        """Create PatchEntries for a given LLVM change and returns them.

        Returns:
            A list of PatchEntries representing the patches for each
            package for the given patch_source.

        Post:
            Unless self.dry_run is True, writes the patch contents to
            the respective <pkg>/files/ workdir for each applicable
            patch.

        Raises:
            TypeError: If the patch_source is not a
                LLVMGitRef or LLVMPullRequest.
        """

        # This is just a dispatch method to the actual methods.
        if isinstance(patch_source, LLVMGitRef):
            return self._make_patches_from_git_ref(patch_source)
        if isinstance(patch_source, LLVMPullRequest):
            return self._make_patches_from_pr(patch_source)
        raise TypeError(
            f"patch_source was invalid type {type(patch_source).__name__}"
        )

    def _make_patches_from_git_ref(
        self,
        patch_source: LLVMGitRef,
    ) -> List[patch_utils.PatchEntry]:
        """Create (and unless dry_run, write) patches for a local git ref.

        The patch goes to <workdir>/cherry/<ref>.patch when the package has
        a cherry/ directory, otherwise to <workdir>/<ref>.patch.

        Raises:
            CherrypickVersionError: If start_ref is not strictly before
                patch_source.
            CherrypickError: If any affected PATCHES.json already lists the
                patch. Checked for every package before anything is written.
        """
        packages = get_changed_packages(
            self.llvm_project_dir, patch_source.git_ref
        )
        # These do not depend on the package, so compute them only once.
        if not self._is_valid_patch_range(self.start_ref, patch_source):
            raise CherrypickVersionError(
                f"'from' ref {self.start_ref} is later or"
                f" same as than 'until' ref {patch_source}"
            )
        title = get_commit_subj(self.llvm_project_dir, patch_source.git_ref)
        from_number = self.start_ref.to_rev(self.llvm_project_dir).number
        until_number = patch_source.to_rev(self.llvm_project_dir).number

        new_patch_entries: List[patch_utils.PatchEntry] = []
        for workdir in self._workdirs_for_packages(packages):
            if (workdir / "cherry").is_dir():
                rel_patch_path = f"cherry/{patch_source.git_ref}.patch"
            else:
                # Some packages don't have a cherry directory.
                rel_patch_path = f"{patch_source.git_ref}.patch"
            new_patch_entries.append(
                patch_utils.PatchEntry(
                    workdir=workdir,
                    metadata={"title": title, "info": []},
                    platforms=list(self.platforms),
                    rel_patch_path=rel_patch_path,
                    version_range={
                        "from": from_number,
                        "until": until_number,
                    },
                )
            )
        # Validate every package before modifying anything, so a conflict in
        # one package doesn't leave patch files written for the others.
        self._raise_if_any_applied(new_patch_entries)
        contents = git_format_patch(
            self.llvm_project_dir,
            patch_source.git_ref,
        )
        self._write_patches(new_patch_entries, contents)
        return new_patch_entries

    def _make_patches_from_pr(
        self, patch_source: LLVMPullRequest
    ) -> List[patch_utils.PatchEntry]:
        """Create (and unless dry_run, write) patches for a GitHub PR.

        The PR is squashed into a single patch named after its cleaned
        title; entries have an open-ended ("until": None) version range.

        Raises:
            CherrypickError: If any affected PATCHES.json already lists the
                patch. Checked for every package before anything is written.
        """
        json_response = get_llvm_github_pull(patch_source.number)
        github_ctx = GitHubPRContext(json_response, self.llvm_project_dir)
        rel_patch_path = f"{github_ctx.full_title_cleaned}.patch"
        contents, packages = github_ctx.git_squash_chain_patch()
        from_number = self.start_ref.to_rev(self.llvm_project_dir).number
        new_patch_entries = [
            patch_utils.PatchEntry(
                workdir=workdir,
                metadata={
                    "title": github_ctx.full_title,
                    "info": [],
                },
                rel_patch_path=rel_patch_path,
                platforms=list(self.platforms),
                version_range={
                    "from": from_number,
                    "until": None,
                },
            )
            for workdir in self._workdirs_for_packages(packages)
        ]
        # Validate every package before modifying anything.
        self._raise_if_any_applied(new_patch_entries)
        self._write_patches(new_patch_entries, contents)
        return new_patch_entries

    def _raise_if_any_applied(
        self, entries: Iterable[patch_utils.PatchEntry]
    ) -> None:
        """Raise CherrypickError if any entry already exists in PATCHES.json."""
        for pe in entries:
            if self.is_patch_applied(pe):
                raise CherrypickError(
                    f"Patch at {pe.rel_patch_path}"
                    " already exists in PATCHES.json"
                )

    def _write_patches(
        self, entries: Iterable[patch_utils.PatchEntry], contents: str
    ) -> None:
        """Write `contents` to each entry's patch path unless dry_run."""
        if self.dry_run:
            return
        for pe in entries:
            _write_patch(pe.title(), contents, pe.patch_path())

    def _workdirs_for_packages(self, packages: Iterable[Path]) -> List[Path]:
        """Map package paths to their <chromiumos_root>/<pkg>/files dirs."""
        return [self.chromiumos_root / pkg / "files" for pkg in packages]

    def is_patch_applied(self, to_check: patch_utils.PatchEntry) -> bool:
        """Return True if the workdir's PATCHES.json has the same patch path.

        Only rel_patch_path is compared; other entry fields are ignored.
        """
        patches_json_file = to_check.workdir / PATCH_METADATA_FILENAME
        with patches_json_file.open(encoding="utf-8") as f:
            patch_entries = patch_utils.json_to_patch_entries(
                to_check.workdir, f
            )
        return any(
            p.rel_patch_path == to_check.rel_patch_path for p in patch_entries
        )

    def _is_valid_patch_range(
        self, from_ref: LLVMGitRef, to_ref: LLVMGitRef
    ) -> bool:
        """Return True if from_ref's revision number is below to_ref's."""
        return (
            from_ref.to_rev(self.llvm_project_dir).number
            < to_ref.to_rev(self.llvm_project_dir).number
        )
278
279
def get_commit_subj(git_root_dir: Path, ref: str) -> str:
    """Return the subject line of commit `ref` in the repo at git_root_dir.

    Raises:
        subprocess.CalledProcessError: If `git show` fails.
    """
    logging.debug("Getting commit subject for %s", ref)
    show_cmd = ["git", "show", "-s", "--format=%s", ref]
    completed = subprocess.run(
        show_cmd,
        cwd=git_root_dir,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )
    subject = completed.stdout.strip()
    logging.debug("  -> %s", subject)
    return subject
292
293
def git_format_patch(git_root_dir: Path, ref: str) -> str:
    """Produce `git format-patch` output for the single commit `ref`.

    Args:
        git_root_dir: Root directory for a given local git repository.
        ref: Git ref to make a patch for.

    Returns:
        The patch file contents, with surrounding whitespace stripped.

    Raises:
        ValueError: If git produces no output for ref^..ref.
        subprocess.CalledProcessError: If `git format-patch` fails.
    """
    logging.debug("Formatting patch for %s^..%s", ref, ref)
    completed = subprocess.run(
        ["git", "format-patch", "--stdout", f"{ref}^..{ref}"],
        cwd=git_root_dir,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )
    patch_text = completed.stdout.strip()
    if patch_text:
        logging.debug("Patch diff is %d lines long", patch_text.count("\n"))
        return patch_text
    raise ValueError(f"No git diff between {ref}^..{ref}")
317
318
def get_llvm_github_pull(pull_number: int) -> Dict[str, Any]:
    """Get information about an LLVM pull request.

    Args:
        pull_number: The llvm/llvm-project pull request number.

    Returns:
        A dictionary containing the JSON response from GitHub.

    Raises:
        RuntimeError: When GitHub responds with an HTTP error status.
    """
    # urlopen() reports HTTP error statuses by raising urllib.error.HTTPError
    # instead of returning a response, so a status check on the returned
    # object alone never sees them. Translate the exception to honor the
    # documented RuntimeError contract.
    from urllib import error as urllib_error

    pull_url = (
        f"https://api.github.com/repos/llvm/llvm-project/pulls/{pull_number}"
    )
    # TODO(ajordanr): If we are ever allowed to use the 'requests' library
    # we should move to that instead of urllib.
    req = request.Request(
        url=pull_url,
        headers={
            "X-GitHub-Api-Version": "2022-11-28",
            "Accept": "application/vnd.github+json",
        },
    )
    try:
        with request.urlopen(req) as f:
            # Defensive: only reachable if a custom opener that does not
            # raise on error statuses has been installed.
            if f.status >= 400:
                raise RuntimeError(
                    f"GitHub response was not OK: {f.status} {f.reason}"
                )
            response = f.read().decode("utf-8")
    except urllib_error.HTTPError as e:
        raise RuntimeError(
            f"GitHub response was not OK: {e.code} {e.reason}"
        ) from e
    return json.loads(response)
348
349
class GitHubPRContext:
    """Metadata and pathing context for a GitHub pull request checkout."""

    def __init__(
        self,
        response: Dict[str, Any],
        llvm_project_dir: Path,
    ) -> None:
        """Create a GitHubPRContext from a GitHub pulls api call.

        Args:
            response: A dictionary formed from the JSON sent by
                the github pulls API endpoint.
            llvm_project_dir: Path to llvm-project git directory.

        Raises:
            KeyError, TypeError, ValueError: If `response` lacks a required
                field or has an unexpected shape. The response is logged
                before re-raising.
        """
        try:
            self.clone_url = response["head"]["repo"]["clone_url"]
            self._title = response["title"]
            # GitHub can send `null` for the body of a PR without a
            # description; normalize it so the commit message body can
            # always be wrapped as text in git_squash_chain_patch().
            self.body = response["body"] or ""
            self.base_ref = response["base"]["sha"]
            self.head_ref = response["head"]["sha"]
            self.llvm_project_dir = llvm_project_dir
            self.number = int(response["number"])
            self._fetched = False
        except (ValueError, KeyError, TypeError):
            # TypeError covers null intermediate objects (e.g. `head.repo`).
            logging.error("Failed to parse GitHub response:\n%s", response)
            raise

    @property
    def full_title(self) -> str:
        """The PR title prefixed with its number, e.g. '[PR123] Title'."""
        return f"[PR{self.number}] {self._title}"

    @property
    def full_title_cleaned(self) -> str:
        """full_title with every non-word character replaced by '-'."""
        return re.sub(r"\W", "-", self.full_title)

    def git_squash_chain_patch(self) -> Tuple[str, Set[Path]]:
        """Replicate a squashed merge commit as a patch file.

        Fetches the PR head, then in a temporary git worktree branched from
        base_ref, squash-merges head_ref and commits it with a message built
        from the PR title and (wrapped) body. The worktree and temporary
        branch are removed afterwards, including on failure.

        Returns:
            A tuple of (patch file contents, set of package paths changed
            between base_ref and the squashed commit).

        Raises:
            subprocess.CalledProcessError: If a git command fails, e.g. the
                squash merge has conflicts.
        """
        self._fetch()
        idx = random.randint(0, 2**32)
        tmpbranch_name = f"squash-branch-{idx}"

        with tempfile.TemporaryDirectory() as dir_str:
            worktree_parent_dir = Path(dir_str)
            commit_message_file = worktree_parent_dir / "commit_message"
            # Need this separate from the commit message, otherwise the
            # dir will be non-empty.
            worktree_dir = worktree_parent_dir / "worktree"
            with commit_message_file.open("w", encoding="utf-8") as f:
                f.write(self.full_title)
                f.write("\n\n")
                f.write(
                    "\n".join(
                        textwrap.wrap(
                            self.body, width=72, replace_whitespace=False
                        )
                    )
                )
                f.write("\n")

            logging.debug("Base ref: %s", self.base_ref)
            logging.debug("Head ref: %s", self.head_ref)
            logging.debug(
                "Creating worktree at '%s' with branch '%s'",
                worktree_dir,
                tmpbranch_name,
            )
            self._run(
                [
                    "git",
                    "worktree",
                    "add",
                    "-b",
                    tmpbranch_name,
                    worktree_dir,
                    self.base_ref,
                ],
                self.llvm_project_dir,
            )
            try:
                self._run(
                    ["git", "merge", "--squash", self.head_ref], worktree_dir
                )
                self._run(
                    [
                        "git",
                        "commit",
                        "-a",
                        "-F",
                        commit_message_file,
                    ],
                    worktree_dir,
                )
                changed_packages = get_changed_packages(
                    worktree_dir, (self.base_ref, "HEAD")
                )
                patch_contents = git_format_patch(worktree_dir, "HEAD")
            finally:
                logging.debug(
                    "Cleaning up worktree and deleting branch %s",
                    tmpbranch_name,
                )
                # --force: after a failed (e.g. conflicting) squash merge the
                # worktree is dirty, and a plain `git worktree remove` would
                # refuse, masking the original error and leaking the branch.
                self._run(
                    ["git", "worktree", "remove", "--force", worktree_dir],
                    self.llvm_project_dir,
                )
                self._run(
                    ["git", "branch", "-D", tmpbranch_name],
                    self.llvm_project_dir,
                )
        return (patch_contents, changed_packages)

    def _fetch(self) -> None:
        """Fetch head_ref from the PR's clone URL, at most once."""
        if not self._fetched:
            logging.debug(
                "Fetching from %s and setting FETCH_HEAD to %s",
                self.clone_url,
                self.head_ref,
            )
            self._run(
                ["git", "fetch", self.clone_url, self.head_ref],
                cwd=self.llvm_project_dir,
            )
            self._fetched = True

    @staticmethod
    def _run(
        cmd: List[Union[str, Path]],
        cwd: Path,
        stdin: int = subprocess.DEVNULL,
    ) -> subprocess.CompletedProcess:
        """Helper for subprocess.run.

        Captures stdout as text and raises CalledProcessError on failure.
        """
        return subprocess.run(
            cmd,
            cwd=cwd,
            stdin=stdin,
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )
499
500
def get_changed_packages(
    llvm_project_dir: Path, ref: Union[str, Tuple[str, str]]
) -> Set[Path]:
    """Map the files changed by a commit (or commit range) to package paths.

    Args:
        llvm_project_dir: Path to llvm-project
        ref: Either a single git ref, in which case that commit's own change
            (ref^..ref) is inspected, or a (from, to) tuple of refs whose
            diff is inspected.

    Returns:
        The set of package paths to patch. LLVM_PKG_PATH is always present.

    Raises:
        TypeError: If ref is neither a str nor a tuple.
    """
    if isinstance(ref, str):
        ref_from, ref_to = ref + "^", ref
    elif isinstance(ref, tuple):
        ref_from, ref_to = ref
    else:
        raise TypeError(f"ref was {type(ref)}; need a tuple or a string")

    logging.debug("Getting git diff between %s..%s", ref_from, ref_to)
    diff_names = subprocess.run(
        ["git", "diff", "--name-only", f"{ref_from}..{ref_to}"],
        cwd=llvm_project_dir,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    ).stdout.splitlines()
    logging.debug("Found %d changed files", len(diff_names))

    # The LLVM ebuild also builds some LLVM projects on x86, so it always
    # receives the patch regardless of which files changed.
    packages = {LLVM_PKG_PATH}
    for name in diff_names:
        if name.startswith("compiler-rt"):
            packages.add(COMPILER_RT_PKG_PATH)
            if "scudo" in name:
                packages.add(SCUDO_PKG_PATH)
        elif name.startswith("libunwind"):
            packages.add(LIBUNWIND_PKG_PATH)
        elif name.startswith(("libcxx", "libcxxabi")):
            packages.add(LIBCXX_PKG_PATH)
        elif name.startswith("lldb"):
            packages.add(LLDB_PKG_PATH)
    return packages
549
550
def _has_repo_child(path: Union[str, Path]) -> bool:
    """Check if a given directory has a `.repo` child directory.

    Useful for checking if a directory is the root of a chromiumos source
    tree. Accepts a str as well as a Path, so raw command-line values can be
    checked directly (`str / str` would otherwise raise TypeError).
    """
    directory = Path(path)
    return directory.is_dir() and (directory / ".repo").is_dir()
558
559
def _autodetect_chromiumos_root(
    parent: Optional[Path] = None,
) -> Optional[Path]:
    """Find the root of the chromiumos source tree from the current workdir.

    Walks from `parent` (default: the current working directory) up through
    its ancestors and returns the first directory containing a `.repo`
    directory.

    Args:
        parent: Directory to start the search from. It is resolved to an
            absolute path so the upward walk is well defined.

    Returns:
        The root directory of the current chromiumos source tree.
        If no such directory is found before reaching the filesystem root
        (which itself is not checked), this returns None.
    """
    current = Path.cwd() if parent is None else parent.resolve()
    # At the filesystem root, `.parent` is the path itself, which ends the
    # walk. (The old `== Path.root` test compared against a class-level
    # property object, never matched, and recursed forever outside a
    # checkout.)
    while current != current.parent:
        if _has_repo_child(current):
            return current
        current = current.parent
    return None
577
578
def _write_patch(title: str, contents: str, path: Path) -> None:
    """Write `contents` to `path` as UTF-8 text, logging the patch title.

    Kept as a standalone function mainly so it can be mocked.
    """
    logging.info("Writing patch '%s' to '%s'", title, path)
    with path.open("w", encoding="utf-8") as out:
        out.write(contents)
584
585
def validate_patch_args(
    positional_args: List[str],
) -> List[Union[LLVMGitRef, LLVMPullRequest]]:
    """Convert each positional argument into a patch source.

    Arguments of the form 'p:NNNN' become LLVMPullRequest(NNNN); anything
    else is treated as a local llvm-project git ref (LLVMGitRef).

    Raises:
        ValueError: If a 'p:' argument is not followed by an integer.
    """
    pr_prefix = "p:"
    patch_sources: List[Union[LLVMGitRef, LLVMPullRequest]] = []
    for arg in positional_args:
        patch_source: Union[LLVMGitRef, LLVMPullRequest]
        if arg.startswith(pr_prefix):
            # Slice off exactly the prefix: str.lstrip("p:") strips any run of
            # 'p'/':' characters and would wrongly accept e.g. 'p:p:123'.
            try:
                pull_request_num = int(arg[len(pr_prefix) :])
            except ValueError as e:
                raise ValueError(
                    f"GitHub Pull Request '{arg}' was not in the format of"
                    f" 'p:NNNN': {e}"
                ) from e
            logging.info("Patching remote GitHub PR '%s'", pull_request_num)
            patch_source = LLVMPullRequest(pull_request_num)
        else:
            logging.info("Patching local ref '%s'", arg)
            patch_source = LLVMGitRef(arg)
        patch_sources.append(patch_source)
    return patch_sources
608
609
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for this script.

    Besides parsing, this configures logging, stores the validated patch
    sources on `args.patch_sources`, and fills in / verifies
    `args.chromiumos_root` and `args.llvm` (both Path objects). Exits via
    parser.error() if either directory cannot be determined.
    """

    parser = argparse.ArgumentParser(
        "get_patch",
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "-c",
        "--chromiumos-root",
        # Path (not str): the value is joined with `/` below and in
        # _has_repo_child(), which fails for plain strings.
        type=Path,
        help="""Path to the chromiumos source tree root.
        Tries to autodetect if not passed.
        """,
    )
    parser.add_argument(
        "-l",
        "--llvm",
        type=Path,
        help="""Path to the llvm dir.
        Tries to autodetect from chromiumos root if not passed.
        """,
    )
    parser.add_argument(
        "-s",
        "--start-ref",
        default="HEAD",
        help="""The starting ref for which to apply patches.
        """,
    )
    parser.add_argument(
        "-p",
        "--platform",
        action="append",
        help="""Apply this patch to the given platform. Common options include
        'chromiumos' and 'android'. Can be specified multiple times to
        apply to multiple platforms. If not passed, platform is set to
        'chromiumos'.
        """,
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Run normally, but don't make any changes. Read-only mode.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose logging.",
    )
    parser.add_argument(
        "ref_or_pr_num",
        nargs="+",
        help="""Git ref or GitHub PR number to make patches.
        To patch a GitHub PR, use the syntax p:NNNN (e.g. 'p:123456').
        """,
        type=str,
    )
    args = parser.parse_args()

    logging.basicConfig(
        format=">> %(asctime)s: %(levelname)s: %(filename)s:%(lineno)d: "
        "%(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    args.patch_sources = validate_patch_args(args.ref_or_pr_num)
    if args.chromiumos_root:
        if not _has_repo_child(args.chromiumos_root):
            parser.error("chromiumos root directly passed but has no .repo")
        logging.debug("chromiumos root directly passed; found and verified")
    elif tmp := _autodetect_chromiumos_root():
        logging.debug("chromiumos root autodetected; found and verified")
        args.chromiumos_root = tmp
    else:
        parser.error(
            "Could not autodetect chromiumos root. Use '-c' to pass the "
            "chromiumos root path directly."
        )

    if not args.llvm:
        if (args.chromiumos_root / LLVM_PROJECT_PATH).is_dir():
            args.llvm = args.chromiumos_root / LLVM_PROJECT_PATH
        else:
            parser.error(
                "Could not autodetect llvm-project dir. Use '-l' to pass the "
                "llvm-project directly"
            )
    return args
699
700
def main() -> None:
    """Entry point: parse arguments, then apply each requested patch source."""
    args = parse_args()

    # Nearly every cherry-pick targets only ChromiumOS, so that is the
    # default when no --platform flag was given.
    platforms: List[str] = args.platform or ["chromiumos"]

    patch_ctx = PatchContext(
        llvm_project_dir=args.llvm,
        chromiumos_root=args.chromiumos_root,
        start_ref=LLVMGitRef(args.start_ref),
        platforms=platforms,
        dry_run=args.dry_run,
    )
    for source in args.patch_sources:
        patch_ctx.apply_patches(source)
719
720
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
723