xref: /aosp_15_r20/external/toolchain-utils/seccomp_tools/mass_seccomp_editor/mass_seccomp_editor.py (revision 760c253c1ed00ce9abd48f8546f08516e57485fe)
1#!/usr/bin/env python3
2
3# Copyright 2021 The ChromiumOS Authors
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Script to make mass, CrOS-wide seccomp changes."""
8
9import argparse
10import re
11import subprocess
12import sys
13import shutil
14from typing import Any, Iterable, Optional
15from dataclasses import dataclass, field
16
# Pre-compiled regexes used to bucket policy file paths by architecture.
# They are tried with re.match() in the order AMD64 -> X86 -> AARCH64 -> ARM
# and the first hit wins, so the order matters: X86_RE would also match
# "x86_64" paths, and ARM_RE would also match "arm64" paths.
# NOTE(review): the leading ".*" lets the arch token match anywhere in the
# path (including directory names), not only in the file name. In ARM_RE the
# optional "(v7)?" group is subsumed by the ".*" that follows it.
AMD64_RE = re.compile(r".*(amd|x86_)64.*\.policy")
X86_RE = re.compile(r".*x86.*\.policy")
AARCH64_RE = re.compile(r".*a(arch|rm)64.*\.policy")
ARM_RE = re.compile(r".*arm(v7)?.*\.policy")
22
23
@dataclass(frozen=True)
class Policies:
    """Buckets of policy file paths, grouped by target architecture.

    Each field is a list of path strings. ``none`` collects paths whose
    architecture could not be determined. The dataclass is frozen, so the
    fields cannot be rebound, but the lists themselves are mutable and are
    appended to while classifying files.
    """

    arm: list[str] = field(default_factory=list)
    x86_64: list[str] = field(default_factory=list)
    x86: list[str] = field(default_factory=list)
    arm64: list[str] = field(default_factory=list)
    none: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, list[str]]:
        """Return a new dict mapping each field name to its list.

        Only the dict is new: its values are the very list objects held by
        this instance (a shallow copy), so mutating a value list also
        mutates this instance.
        """
        return dict(vars(self))
37
38
def main():
    """Run the program from the command line.

    Finds policy files for the requested packages and buckets them by
    architecture. Prints one status line per file. With --edit, it also
    appends the missing syscalls (or, with --force, all requested syscalls)
    to each file.

    Exits with 0 when every package query matched at least one policy file,
    1 when no syscalls were requested, and 2 when some query found nothing.
    """
    args = parse_args()
    if all(x is None for x in [args.all, args.b64, args.b32, args.none]):
        print(
            "Require at least one of {--all, --b64, --b32, --none}",
            file=sys.stderr,
        )
        sys.exit(1)
    matches, success = find_potential_policy_files(args.packages)
    separated = _classify_policy_files(matches)
    syscall_lookup_table = _make_syscall_lookup_table(args)

    _print_report(separated, syscall_lookup_table)
    if args.edit:
        _apply_edits(separated, syscall_lookup_table, args.force, args.yes)

    sys.exit(0 if success else 2)


def _classify_policy_files(paths: Iterable[str]) -> "Policies":
    """Bucket policy file paths by the architecture named in each path.

    Duplicate paths (e.g. one file matched by two package queries) are kept
    only once, so a file is never reported or edited twice.
    """
    separated = Policies()
    # First match wins, so the 64-bit patterns must come first: X86_RE also
    # matches x86_64 paths and ARM_RE also matches arm64 paths.
    buckets = (
        (AMD64_RE, separated.x86_64),
        (X86_RE, separated.x86),
        (AARCH64_RE, separated.arm64),
        (ARM_RE, separated.arm),
    )
    for path in dict.fromkeys(paths):  # order-preserving de-duplication
        for regex, bucket in buckets:
            if regex.match(path):
                bucket.append(path)
                break
        else:
            separated.none.append(path)
    return separated


def _print_report(
    separated: "Policies", syscall_lookup_table: dict[str, list[str]]
) -> None:
    """Print one status line per file: E (error), _ (present), M (missing)."""
    for type_, paths in separated.to_dict().items():
        syscalls = syscall_lookup_table[type_]
        for fp in paths:
            missing = check_missing_syscalls(syscalls, fp)
            if missing is None:
                print(f"E ({type_}) {fp}")
            elif not missing:
                print(f"_ ({type_}) {fp}")
            else:
                # Sort so the output is stable across runs (sets are unordered).
                missing_str = ",".join(sorted(missing))
                print(f"M ({type_}) {fp} :: {missing_str}")


def _apply_edits(
    separated: "Policies",
    syscall_lookup_table: dict[str, list[str]],
    force: bool,
    assume_yes: bool,
) -> None:
    """Append missing (or, with force, all) requested syscalls to each file."""
    for type_, paths in separated.to_dict().items():
        syscalls = syscall_lookup_table[type_]
        for fp in paths:
            if force:
                _confirm_add(fp, syscalls, assume_yes)
                continue
            missing = check_missing_syscalls(syscalls, fp)
            if missing is None:
                # The file could not be decoded; do not claim it is fine.
                print(f"Skipping {fp} ({type_}): file could not be decoded")
            elif not missing:
                print(f"Already good for {fp} ({type_})")
            else:
                _confirm_add(fp, sorted(missing), assume_yes)
97
98
def _split_syscalls(raw: str) -> list[str]:
    """Split a comma separated syscall list, dropping blank entries."""
    # Blank entries (from "a,,b" or a trailing comma) would otherwise be
    # searched for and written out as a bare ": 1" rule.
    return [name for name in (part.strip() for part in raw.split(",")) if name]


def _make_syscall_lookup_table(args: Any) -> dict[str, list[str]]:
    """Make lookup table, segmented by all/b32/b64/none policies.

    Args:
      args: Direct output from parse_args.

    Returns:
      dict of syscalls we want to search for in each policy file,
      where the key is the policy file arch ("arm", "x86_64", "x86",
      "arm64" or "none"), and the value is a list of syscalls as strings.
      Each list holds no blank names and no duplicates, and keeps the
      order in which the syscalls were first requested.
    """
    syscall_lookup_table = Policies().to_dict()
    if args.all:
        split_syscalls = _split_syscalls(args.all)
        for v in syscall_lookup_table.values():
            v.extend(split_syscalls)
    if args.b32:
        split_syscalls = _split_syscalls(args.b32)
        syscall_lookup_table["x86"].extend(split_syscalls)
        syscall_lookup_table["arm"].extend(split_syscalls)
    if args.b64:
        split_syscalls = _split_syscalls(args.b64)
        syscall_lookup_table["x86_64"].extend(split_syscalls)
        syscall_lookup_table["arm64"].extend(split_syscalls)
    if args.none:
        split_syscalls = _split_syscalls(args.none)
        syscall_lookup_table["none"].extend(split_syscalls)
    # A syscall may be requested through several flags (e.g. --all and
    # --b32). Keep only its first occurrence so it is checked and added once.
    return {
        arch: list(dict.fromkeys(syscalls))
        for arch, syscalls in syscall_lookup_table.items()
    }
127
128
def _confirm_add(fp: str, syscalls: Iterable[str], noninteractive=None):
    """Interactive confirmation check you wish to add a syscall.

    Args:
      fp: filepath of the file to edit.
      syscalls: iterable of syscalls to append to the file. It is consumed
        exactly once, so one-shot iterators (e.g. generators) work.
      noninteractive: If truthy, add the syscalls without asking (and
        without printing a confirmation).
    """
    # Materialize exactly once. Otherwise a one-shot iterator would be
    # exhausted by building the prompt before anything is written. Sorting
    # also keeps the prompt order stable when a set is passed.
    to_add = sorted(syscalls)
    if noninteractive:
        _update_seccomp(fp, to_add)
        return
    syscalls_str = ",".join(to_add)
    user_input = input(f"Add {syscalls_str} for {fp}? [y/N]> ")
    if user_input.lower().startswith("y"):
        _update_seccomp(fp, to_add)
        print("Edited!")
    else:
        print(f"Skipping {fp}")
147
148
def check_missing_syscalls(syscalls: list[str], fp: str) -> Optional[set[str]]:
    """Return which specified syscalls are missing in the given file.

    A syscall counts as present when some line starts with ``<syscall>:``
    followed by optional whitespace and ``1`` (an unconditional allow rule).
    Conditional rules such as ``read: arg0 == 1`` do not count, so such a
    syscall is reported as missing.

    Args:
      syscalls: syscall names to look for; duplicates are allowed.
      fp: path of the policy file to inspect.

    Returns:
      The set of requested syscalls that have no matching rule, or None if
      the file is not valid UTF-8.
    """
    try:
        # Decode explicitly as UTF-8 so the UnicodeDecodeError check below
        # does not depend on the caller's locale.
        with open(fp, encoding="utf-8") as f:
            lines = f.readlines()
    except UnicodeDecodeError:
        return None
    missing_syscalls = set()
    # Iterate unique names: a syscall may be requested twice or allowed on
    # several lines, neither of which should raise.
    for syscall in set(syscalls):
        pattern = re.compile(re.escape(syscall) + r":\s*1")
        if not any(pattern.match(line) for line in lines):
            missing_syscalls.add(syscall)
    return missing_syscalls
162
163
def _update_seccomp(fp: str, missing_syscalls: Iterable[str]):
    """Append an unconditional ``<syscall>: 1`` rule for each given syscall.

    Rules are appended in sorted order. If the file is non-empty and does not
    already end with a newline, one is inserted first. This stops the first
    new rule from being glued onto the last existing line.

    Args:
      fp: path of the policy file to append to.
      missing_syscalls: syscall names to add.
    """
    import os

    new_rules = "".join(f"{name}: 1\n" for name in sorted(missing_syscalls))
    if not new_rules:
        return
    # Binary append+read mode: we can seek back to inspect the last byte,
    # while every write still lands at the end of the file.
    with open(fp, "a+b") as f:
        f.seek(0, os.SEEK_END)
        if f.tell() > 0:
            f.seek(-1, os.SEEK_END)
            if f.read(1) != b"\n":
                new_rules = "\n" + new_rules
        f.write(new_rules.encode("utf-8"))
170
171
def _search_cmd(query: str, use_fd=True) -> list[str]:
    """Build the argv for a search for files matching ``*<query>*.policy``.

    Uses ``fdfind`` when ``use_fd`` is truthy and ``fdfind`` is on PATH;
    otherwise it falls back to ``find .``. ``query`` is spliced into the
    regex verbatim, so it may itself contain regex syntax (e.g. ``[-_]``).
    """
    pattern = f"^.*{query}.*\\.policy$"
    prefer_fd = use_fd and shutil.which("fdfind") is not None
    if not prefer_fd:
        return ["find", ".", "-regex", pattern, "-type", "f"]
    return ["fdfind", "-t", "f", "--full-path", pattern]
189
190
def find_potential_policy_files(packages: list[str]) -> tuple[list[str], bool]:
    """Find potentially related policy files to the given packages.

    Hyphens and underscores in each package name are treated as
    interchangeable when searching. A warning goes to stderr for every
    package with no hits.

    Args:
      packages: package names to search for.

    Returns:
      (policy_files, successful): A list of policy file paths, and a boolean
      indicating whether all queries were successful in finding at least
      one related policy file.

    Raises:
      subprocess.CalledProcessError: if the search command exits non-zero.
    """
    found: list[str] = []
    every_query_matched = True
    for package in packages:
        # Packages often swap "-" and "_" between their names and their
        # policy files, so accept either character in each position.
        pattern = re.sub(r"[-_]", "[-_]", package)
        result = subprocess.run(
            _search_cmd(pattern),
            stdout=subprocess.PIPE,
            check=True,
        )
        output = result.stdout.decode("utf-8")
        hits = list(filter(None, output.split("\n")))
        if hits:
            found.extend(hits)
            continue
        print(f"WARNING: No matches found for {package}", file=sys.stderr)
        every_query_matched = False
    return found, every_query_matched
218
219
def parse_args() -> Any:
    """Parse this script's command line arguments and return the namespace."""
    description = (
        "Check for missing syscalls in"
        " seccomp policy files, or make"
        " mass seccomp changes.\n\n"
        "The format of this output follows the template:\n"
        "    status (arch) local/policy/filepath :: syscall,syscall,syscall\n"
        'Where the status can be "_" for present, "M" for missing,'
        ' or "E" for Error\n\n'
        "Example:\n"
        "    mass_seccomp_editor.py --all fstatfs --b32 fstatfs64"
        " modemmanager\n\n"
        "Exit Codes:\n"
        "    '0' for successfully found specific policy files\n"
        "    '1' for python-related error.\n"
        "    '2' for no matched policy files for a given query."
    )
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("packages", nargs="+")

    # These flags all take a comma separated syscall list. They differ only
    # in which policy-file architectures the list applies to.
    syscall_flags = (
        ("--all", "comma separated syscalls to check in all policy files"),
        ("--b64", "Comma separated syscalls to check in 64bit architectures"),
        ("--b32", "Comma separated syscalls to check in 32bit architectures"),
        ("--none", "Comma separated syscalls to check in unknown architectures"),
    )
    for flag, help_text in syscall_flags:
        parser.add_argument(flag, type=str, metavar="syscalls", help=help_text)

    parser.add_argument(
        "--edit",
        action="store_true",
        help=(
            "Make changes to the listed files,"
            " rather than just printing out what is missing"
        ),
    )
    parser.add_argument(
        "-y",
        "--yes",
        action="store_true",
        help='Say "Y" to all interactive checks',
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help=(
            "Edit all files, regardless of missing status."
            " Does nothing without --edit."
        ),
    )
    return parser.parse_args()
283
284
# Run the command-line tool only when executed directly, not when imported.
if __name__ == "__main__":
    main()
287