#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Creates and manages token databases.

This module manages reading tokenized strings from ELF files and building and
maintaining token databases.
"""

import argparse
from datetime import datetime
import glob
import itertools
import json
import logging
import os
from pathlib import Path
import re
import struct
import sys
from typing import (
    cast,
    Any,
    Callable,
    Iterable,
    Iterator,
    Pattern,
    Set,
    TextIO,
)

try:
    from pw_tokenizer import elf_reader, tokens
except ImportError:
    # Append this path to the module search path to allow running this module
    # without installing the pw_tokenizer package.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from pw_tokenizer import elf_reader, tokens

_LOG = logging.getLogger('pw_tokenizer')


def _elf_reader(elf) -> elf_reader.Elf:
    return elf if isinstance(elf, elf_reader.Elf) else elf_reader.Elf(elf)


# Magic number used to indicate the beginning of a tokenized string entry. This
# value MUST match the value of _PW_TOKENIZER_ENTRY_MAGIC in
# pw_tokenizer/public/pw_tokenizer/internal/tokenize_string.h.
_TOKENIZED_ENTRY_MAGIC = 0xBAA98DEE
_ENTRY = struct.Struct('<4I')
_TOKENIZED_ENTRY_SECTIONS = re.compile(r'^\.pw_tokenizer.entries(?:\.[_\d]+)?$')

_ERROR_HANDLER = 'surrogateescape'  # How to deal with UTF-8 decoding errors


class Error(Exception):
    """Failed to extract token entries from an ELF file."""


def _read_tokenized_entries(
    data: bytes, domain: Pattern[str]
) -> Iterator[tokens.TokenizedStringEntry]:
    index = 0

    while index + _ENTRY.size <= len(data):
        magic, token, domain_len, string_len = _ENTRY.unpack_from(data, index)

        if magic != _TOKENIZED_ENTRY_MAGIC:
            raise Error(
                f'Expected magic number 0x{_TOKENIZED_ENTRY_MAGIC:08x}, '
                f'found 0x{magic:08x}'
            )

        start = index + _ENTRY.size
        index = start + domain_len + string_len

        # Create the entries, trimming null terminators.
        entry = tokens.TokenizedStringEntry(
            token,
            data[start + domain_len : index - 1].decode(errors=_ERROR_HANDLER),
            data[start : start + domain_len - 1].decode(errors=_ERROR_HANDLER),
        )

        if data[start + domain_len - 1] != 0:
            raise Error(
                f'Domain {entry.domain} for {entry.string} not null terminated'
            )

        if data[index - 1] != 0:
            raise Error(f'String {entry.string} is not null terminated')

        if domain.fullmatch(entry.domain):
            yield entry
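
# Each entry parsed above is a 16-byte little-endian header (magic, token,
# domain length, string length) followed by the null-terminated domain and
# string bytes, with both lengths counting the terminator. A minimal sketch of
# one such entry, using a made-up token value and the default "" domain:
#
#     _ENTRY.pack(_TOKENIZED_ENTRY_MAGIC, 0x1234, 1, 6) + b'\0' + b'hello\0'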


def _database_from_elf(elf, domain: Pattern[str]) -> tokens.Database:
    """Reads the tokenized strings from an elf_reader.Elf or ELF file object."""
    _LOG.debug('Reading tokenized strings in domain "%s" from %s', domain, elf)

    reader = _elf_reader(elf)

    # Read tokenized string entries.
    section_data = reader.dump_section_contents(_TOKENIZED_ENTRY_SECTIONS)
    if section_data is not None:
        return tokens.Database(_read_tokenized_entries(section_data, domain))

    return tokens.Database([])


def tokenization_domains(elf) -> Iterator[str]:
    """Lists all tokenization domains in an ELF file."""
    reader = _elf_reader(elf)
    section_data = reader.dump_section_contents(_TOKENIZED_ENTRY_SECTIONS)
    if section_data is not None:
        yield from frozenset(
            e.domain
            for e in _read_tokenized_entries(section_data, re.compile('.*'))
        )


def read_tokenizer_metadata(elf) -> dict[str, int]:
    """Reads the metadata entries from an ELF."""
    sections = _elf_reader(elf).dump_section_contents(r'\.pw_tokenizer\.info')

    metadata: dict[str, int] = {}
    if sections is not None:
        for key, value in struct.iter_unpack('12sI', sections):
            try:
                metadata[key.rstrip(b'\0').decode()] = value
            except UnicodeDecodeError as err:
                _LOG.error(
                    'Failed to decode metadata key %r: %s',
                    key.rstrip(b'\0'),
                    err,
                )

    return metadata
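
# Each metadata record unpacked above is a 12-byte null-padded key followed by
# a 32-bit unsigned value; for example (key and value made up for
# illustration), struct.pack('12sI', b'hash_length', 128) would decode to
# {'hash_length': 128}.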


def _database_from_strings(strings: list[str]) -> tokens.Database:
    """Generates a C and C++ compatible database from untokenized strings."""
    # Generate a C-compatible database from the fixed length hash.
    c_db = tokens.Database.from_strings(strings, tokenize=tokens.c_hash)

    # Generate a C++ compatible database by allowing the hash to follow the
    # string length.
    cpp_db = tokens.Database.from_strings(
        strings, tokenize=tokens.pw_tokenizer_65599_hash
    )

    # Use a union of the C and C++ compatible databases.
    return tokens.Database.merged(c_db, cpp_db)


def _database_from_json(fd) -> tokens.Database:
    return _database_from_strings(json.load(fd))


def _load_token_database(  # pylint: disable=too-many-return-statements
    db, domain: Pattern[str]
) -> tokens.Database:
    """Loads a Database from supported database types.

    Supports Database objects, JSONs, ELFs, CSVs, and binary databases.
    """
    if db is None:
        return tokens.Database()

    if isinstance(db, tokens.Database):
        return db

    if isinstance(db, elf_reader.Elf):
        return _database_from_elf(db, domain)

    # If it's a str, it might be a path. Check if it's an ELF, CSV, or JSON.
    if isinstance(db, (str, Path)):
        if not os.path.exists(db):
            raise FileNotFoundError(f'"{db}" is not a path to a token database')

        if Path(db).is_dir():
            return tokens.DatabaseFile.load(Path(db))

        # Read the path as an ELF file.
        with open(db, 'rb') as fd:
            if elf_reader.compatible_file(fd):
                return _database_from_elf(fd, domain)

        # Generate a database from JSON.
        if str(db).endswith('.json'):
            with open(db, 'r', encoding='utf-8') as json_fd:
                return _database_from_json(json_fd)

        # Read the path as a packed binary or CSV file.
        return tokens.DatabaseFile.load(Path(db))

    # Assume that it's a file object and check if it's an ELF.
    if elf_reader.compatible_file(db):
        return _database_from_elf(db, domain)

    # Read the database as JSON, CSV, or packed binary from a file object's
    # path.
    if hasattr(db, 'name') and os.path.exists(db.name):
        if db.name.endswith('.json'):
            return _database_from_json(db)

        return tokens.DatabaseFile.load(Path(db.name))

    # Read CSV directly from the file object.
    return tokens.Database(tokens.parse_csv(db))


def load_token_database(
    *databases, domain: Pattern[str] = re.compile('.*')  # Load all by default
) -> tokens.Database:
    """Loads a Database from supported database types.

    Supports Database objects, JSONs, ELFs, CSVs, and binary databases.
    """
    return tokens.Database.merged(
        *(_load_token_database(db, domain) for db in databases)
    )
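
# Illustrative use (the paths are hypothetical): merge the entries read from
# an ELF and a CSV database into a single in-memory database.
#
#     db = load_token_database('out/firmware.elf', 'legacy_tokens.csv')
#     _LOG.info('Loaded %d entries', len(db))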


def database_summary(db: tokens.Database) -> dict[str, Any]:
    """Returns a simple report of properties of the database."""
    present = [entry for entry in db.entries() if not entry.date_removed]
    collisions = {
        token: list(e.string for e in entries)
        for token, entries in db.collisions()
    }

    # Add 1 to each string's size to account for the null terminator.
    return dict(
        present_entries=len(present),
        present_size_bytes=sum(len(entry.string) + 1 for entry in present),
        total_entries=len(db.entries()),
        total_size_bytes=sum(len(entry.string) + 1 for entry in db.entries()),
        collisions=collisions,
    )
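
# For example, a database whose only live entries are "on" and "off" would be
# summarized as (sizes include the null terminators):
#
#     {'present_entries': 2, 'present_size_bytes': 7, 'total_entries': 2,
#      'total_size_bytes': 7, 'collisions': {}}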


_DatabaseReport = dict[str, dict[str, dict[str, Any]]]


def generate_reports(paths: Iterable[Path]) -> _DatabaseReport:
    """Returns a dictionary with information about the provided databases."""
    reports: _DatabaseReport = {}

    for path in paths:
        domains = ['']
        if path.is_file():
            with path.open('rb') as file:
                if elf_reader.compatible_file(file):
                    domains = list(tokenization_domains(file))

        domain_reports = {}

        for domain in domains:
            domain_reports[domain] = database_summary(
                load_token_database(path, domain=re.compile(domain))
            )

        reports[str(path)] = domain_reports

    return reports
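
# The report maps each input path to per-domain summaries, e.g. (path and
# counts are illustrative):
#
#     {'out/firmware.elf': {'': {'present_entries': 123, ...}}}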


def _handle_create(
    databases,
    database: Path,
    force: bool,
    output_type: str,
    include: list,
    exclude: list,
    replace: list,
) -> None:
    """Creates a token database file from one or more ELF files."""
    if not force and database.exists():
        raise FileExistsError(
            f'The file {database} already exists! Use --force to overwrite.'
        )

    if not database.parent.exists():
        database.parent.mkdir(parents=True)

    if output_type == 'directory':
        if str(database) == '-':
            raise ValueError(
                'Cannot specify "-" (stdout) for directory databases'
            )

        database.mkdir(exist_ok=True)
        database = database / f'database{tokens.DIR_DB_SUFFIX}'
        output_type = 'csv'

    if str(database) == '-':
        # Must write bytes to stdout; use sys.stdout.buffer.
        fd = sys.stdout.buffer
    else:
        fd = database.open('wb')

    db = tokens.Database.merged(*databases)
    db.filter(include, exclude, replace)

    with fd:
        if output_type == 'csv':
            tokens.write_csv(db, fd)
        elif output_type == 'binary':
            tokens.write_binary(db, fd)
        else:
            raise ValueError(f'Unknown database type "{output_type}"')

    _LOG.info(
        'Wrote database with %d entries to %s as %s',
        len(db),
        fd.name,
        output_type,
    )


def _handle_add(
    token_database: tokens.DatabaseFile,
    databases: list[tokens.Database],
    commit: str | None,
) -> None:
    initial = len(token_database)
    if commit:
        entries = itertools.chain.from_iterable(
            db.entries() for db in databases
        )
        token_database.add_and_discard_temporary(entries, commit)
    else:
        for source in databases:
            token_database.add(source.entries())

        token_database.write_to_file()

    number_of_changes = len(token_database) - initial

    if number_of_changes:
        _LOG.info(
            'Added %d entries to %s', number_of_changes, token_database.path
        )


def _handle_mark_removed(
    token_database: tokens.DatabaseFile,
    databases: list[tokens.Database],
    date: datetime | None,
):
    marked_removed = token_database.mark_removed(
        (
            entry
            for entry in tokens.Database.merged(*databases).entries()
            if not entry.date_removed
        ),
        date,
    )

    token_database.write_to_file(rewrite=True)

    _LOG.info(
        'Marked %d of %d entries as removed in %s',
        len(marked_removed),
        len(token_database),
        token_database.path,
    )


def _handle_purge(token_database: tokens.DatabaseFile, before: datetime | None):
    purged = token_database.purge(before)
    token_database.write_to_file(rewrite=True)

    _LOG.info('Removed %d entries from %s', len(purged), token_database.path)


def _handle_report(token_database_or_elf: list[Path], output: TextIO) -> None:
    json.dump(generate_reports(token_database_or_elf), output, indent=2)
    output.write('\n')


def expand_paths_or_globs(*paths_or_globs: str) -> Iterable[Path]:
    """Expands any globs in a list of paths; raises FileNotFoundError."""
    for path_or_glob in paths_or_globs:
        if os.path.exists(path_or_glob):
            # This is a valid path; yield it without evaluating it as a glob.
            yield Path(path_or_glob)
        else:
            paths = glob.glob(path_or_glob, recursive=True)

            # If no paths were found and the path is not a glob, raise an Error.
            if not paths and not any(c in path_or_glob for c in '*?[]!'):
                raise FileNotFoundError(f'{path_or_glob} is not a valid path')

            for path in paths:
                # Resolve globs to JSON, CSV, or compatible binary files.
                if elf_reader.compatible_file(path) or path.endswith(
                    ('.csv', '.json')
                ):
                    yield Path(path)
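
# For example (hypothetical arguments), expand_paths_or_globs('tokens.csv',
# 'out/**/*.elf') yields Path('tokens.csv') as-is if it exists, plus every
# recursive glob match that is a compatible ELF, CSV, or JSON file; a
# nonexistent argument with no glob characters raises FileNotFoundError.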


class ExpandGlobs(argparse.Action):
    """Argparse action that expands and appends paths."""

    def __call__(self, parser, namespace, values, unused_option_string=None):
        setattr(namespace, self.dest, list(expand_paths_or_globs(*values)))


def _read_elf_with_domain(
    elf: Path, domain: Pattern[str]
) -> Iterable[tokens.Database]:
    for path in expand_paths_or_globs(str(elf)):
        with path.open('rb') as file:
            if not elf_reader.compatible_file(file):
                raise ValueError(
                    f'{elf} is not an ELF file, '
                    f'but the "{domain}" domain was specified'
                )

            yield _database_from_elf(file, domain)


def parse_domain(path: Path | str) -> tuple[Path, Pattern[str] | None]:
    """Extracts an optional domain regex pattern suffix from a path."""
    path = Path(path)
    delimiters = path.name.count('#')

    if delimiters == 0:
        return path, None

    if delimiters == 1:
        name, domain = path.name.split('#')
        return path.with_name(name), re.compile(domain)

    raise ValueError(f'{path} has {delimiters} "#" delimiters; expected 0 or 1')
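
# For example (a hypothetical file name), parse_domain('firmware.elf#MY_DOMAIN')
# returns (Path('firmware.elf'), re.compile('MY_DOMAIN')), and a path with no
# '#' suffix returns (path, None).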


def _parse_paths(values: Iterable[str]) -> list[tokens.Database]:
    databases: list[tokens.Database] = []
    paths: Set[Path] = set()

    for value in values:
        path, domain = parse_domain(value)
        if domain is None:
            paths.update(expand_paths_or_globs(value))
        else:
            databases.extend(_read_elf_with_domain(path, domain))

    for path in paths:
        databases.append(load_token_database(path))

    return databases


class LoadTokenDatabases(argparse.Action):
    """Argparse action that reads token databases from paths or globs.

    ELF files may have #domain appended to them to specify a tokenization domain
    other than the default.
    """

    def __call__(self, parser, namespace, values, option_string=None) -> None:
        try:
            setattr(namespace, self.dest, _parse_paths(values))
            return
        except (
            ValueError,
            FileNotFoundError,
            tokens.DatabaseFormatError,
        ) as err:
            error = str(err)

        parser.error(f'argument elf_or_token_database: {error}')


def token_databases_parser(nargs: str = '+') -> argparse.ArgumentParser:
    """Returns an argument parser for reading token databases.

    These arguments can be added to another parser using the parents arg.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        'databases',
        metavar='elf_or_token_database',
        nargs=nargs,
        action=LoadTokenDatabases,
        help=(
            'ELF or token database files from which to read strings and '
            'tokens. For ELF files, the tokenization domain to read from '
            'may be specified after the path as #domain_name (e.g. '
            'foo.elf#TEST_DOMAIN). Unless specified, only the default '
            'domain ("") is read from ELF files; .* reads all domains. '
            'Globs are expanded to compatible database files.'
        ),
    )
    return parser
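
# A minimal sketch of reusing these arguments in another tool (the program
# name and path are hypothetical):
#
#     parser = argparse.ArgumentParser(
#         prog='my_tool', parents=[token_databases_parser()]
#     )
#     args = parser.parse_args(['out/firmware.elf#.*'])  # reads all domains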


def _parse_args() -> tuple[Callable[..., None], argparse.Namespace]:
    """Parse and return command line arguments."""

    def year_month_day(value: str) -> datetime:
        if value == 'today':
            return datetime.now()

        return datetime.fromisoformat(value)

    year_month_day.__name__ = 'year-month-day (YYYY-MM-DD)'

    # Shared command line options.
    option_db = argparse.ArgumentParser(add_help=False)
    option_db.add_argument(
        '-d',
        '--database',
        dest='token_database',
        type=lambda arg: tokens.DatabaseFile.load(Path(arg)),
        required=True,
        help='The database file to update.',
    )

    option_tokens = token_databases_parser('*')

    # Top-level argument parser.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.set_defaults(handler=lambda **_: parser.print_help())

    subparsers = parser.add_subparsers(
        help='Tokenized string database management actions:'
    )

    # The 'create' command creates a database file.
    subparser = subparsers.add_parser(
        'create',
        parents=[option_tokens],
        help=(
            'Creates a database with tokenized strings from one or more '
            'sources.'
        ),
    )
    subparser.set_defaults(handler=_handle_create)
    subparser.add_argument(
        '-d',
        '--database',
        required=True,
        type=Path,
        help='Path to the database file to create; use - for stdout.',
    )
    subparser.add_argument(
        '-t',
        '--type',
        dest='output_type',
        choices=('csv', 'binary', 'directory'),
        default='csv',
        help='Which type of database to create. (default: csv)',
    )
    subparser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='Overwrite the database if it exists.',
    )
    subparser.add_argument(
        '-i',
        '--include',
        type=cast(Callable[[str], Pattern[str]], re.compile),
        default=[],
        action='append',
        help=(
            'If provided, at least one of these regular expressions must '
            'match for a string to be included in the database.'
        ),
    )
    subparser.add_argument(
        '-e',
        '--exclude',
        type=cast(Callable[[str], Pattern[str]], re.compile),
        default=[],
        action='append',
        help=(
            'If provided, none of these regular expressions may match for a '
            'string to be included in the database.'
        ),
    )

    unescaped_slash = re.compile(r'(?<!\\)/')

    def replacement(value: str) -> tuple[Pattern, str]:
        try:
            find, sub = unescaped_slash.split(value, 1)
        except ValueError as _err:
            raise argparse.ArgumentTypeError(
                'replacements must be specified as "search_regex/replacement"'
            )

        try:
            return re.compile(find.replace(r'\/', '/')), sub
        except re.error as err:
            raise argparse.ArgumentTypeError(
                f'"{value}" is not a valid regular expression: {err}'
            )
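
    # For example (illustrative), the value 'secret_\w+/REDACTED' compiles the
    # regex 'secret_\w+' and substitutes 'REDACTED'; a literal '/' in the
    # search regex must be written as '\/'.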

    subparser.add_argument(
        '--replace',
        type=replacement,
        default=[],
        action='append',
        help=(
            'If provided, replaces text that matches a regular expression. '
            'This can be used to replace sensitive terms in a token '
            'database that will be distributed publicly. The expression and '
            'replacement are specified as "search_regex/replacement". '
            'Plain slash characters in the regex must be escaped with a '
            r'backslash (\/). The replacement text may include '
            'backreferences for captured groups in the regex.'
        ),
    )

    # The 'add' command adds strings to a database from a set of ELFs.
    subparser = subparsers.add_parser(
        'add',
        parents=[option_db, option_tokens],
        help=(
            'Adds new strings to a database with tokenized strings from a set '
            'of ELF files or other token databases. Missing entries are NOT '
            'marked as removed.'
        ),
    )
    subparser.set_defaults(handler=_handle_add)
    subparser.add_argument(
        '--discard-temporary',
        dest='commit',
        help=(
            'Deletes temporary tokens in memory and on disk when a CSV exists '
            'within a commit. Afterwards, new strings are added to the '
            'database from a set of ELF files or other token databases. '
            'Missing entries are NOT marked as removed.'
        ),
    )

    # The 'mark_removed' command marks removed entries to match a set of ELFs.
    subparser = subparsers.add_parser(
        'mark_removed',
        parents=[option_db, option_tokens],
        help=(
            'Updates a database with tokenized strings from a set of strings. '
            'Strings not present in the set remain in the database but are '
            'marked as removed. New strings are NOT added.'
        ),
    )
    subparser.set_defaults(handler=_handle_mark_removed)
    subparser.add_argument(
        '--date',
        type=year_month_day,
        help=(
            'The removal date to use for all strings. '
            'May be YYYY-MM-DD or "today". (default: today)'
        ),
    )

    # The 'purge' command removes old entries.
    subparser = subparsers.add_parser(
        'purge',
        parents=[option_db],
        help='Purges removed strings from a database.',
    )
    subparser.set_defaults(handler=_handle_purge)
    subparser.add_argument(
        '-b',
        '--before',
        type=year_month_day,
        help=(
            'Delete all entries removed on or before this date. '
            'May be YYYY-MM-DD or "today".'
        ),
    )

    # The 'report' command prints a report about a database.
    subparser = subparsers.add_parser(
        'report', help='Prints a report about a database.'
    )
    subparser.set_defaults(handler=_handle_report)
    subparser.add_argument(
        'token_database_or_elf',
        nargs='+',
        action=ExpandGlobs,
        help=(
            'The ELF files or token databases about which to generate '
            'reports.'
        ),
    )
    subparser.add_argument(
        '-o',
        '--output',
        type=argparse.FileType('w'),
        default=sys.stdout,
        help='The file to which to write the output; use - for stdout.',
    )

    args = parser.parse_args()

    handler = args.handler
    del args.handler

    return handler, args


def _init_logging(level: int) -> None:
    _LOG.setLevel(logging.DEBUG)
    log_to_stderr = logging.StreamHandler()
    log_to_stderr.setLevel(level)
    log_to_stderr.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s',
            datefmt='%H:%M:%S',
        )
    )

    _LOG.addHandler(log_to_stderr)


def _main(handler: Callable[..., None], args: argparse.Namespace) -> int:
    _init_logging(logging.INFO)
    handler(**vars(args))
    return 0


if __name__ == '__main__':
    sys.exit(_main(*_parse_args()))