# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import bisect
from collections.abc import (
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Sequence,
    ValuesView,
)
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    Any,
    BinaryIO,
    IO,
    NamedTuple,
    overload,
    Pattern,
    TextIO,
    TypeVar,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')

def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)

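# Example (illustrative): computing a token for a format string. c_hash()
# mirrors the C tokenizer, which hashes at most DEFAULT_C_HASH_LENGTH
# characters; for strings shorter than that, both functions agree.
#
#     token = pw_tokenizer_65599_hash('The answer is: %s')
#     assert token == c_hash('The answer is: %s')

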
@dataclass(frozen=True, eq=True, order=True)
class _EntryKey:
    """Uniquely refers to an entry."""

    domain: str
    token: int
    string: str


class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    def __init__(
        self,
        token: int,
        string: str,
        domain: str = DEFAULT_DOMAIN,
        date_removed: datetime | None = None,
    ) -> None:
        self._key = _EntryKey(
            ''.join(domain.split()),
            token,
            string,
        )
        self.date_removed = date_removed

    @property
    def token(self) -> int:
        return self._key.token

    @property
    def string(self) -> str:
        return self._key.string

    @property
    def domain(self) -> str:
        return self._key.domain

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return self._key

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __eq__(self, other: Any) -> bool:
        return (
            self.key() == other.key()
            and self.date_removed == other.date_removed
        )

    def __lt__(self, other: Any) -> bool:
        """Sorts the entry by domain, token, date removed, then string."""
        if self.domain != other.domain:
            return self.domain < other.domain

        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(token=0x{self.token:08x}, '
            f'string={self.string!r}, domain={self.domain!r})'
        )

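# Example (illustrative): an entry stores an arbitrary token alongside its
# string; the token is typically, but not necessarily, the string's hash.
#
#     entry = TokenizedStringEntry(0x12345678, 'Hello, world!')
#     assert entry.domain == DEFAULT_DOMAIN and str(entry) == 'Hello, world!'
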

_TokenToEntries = dict[int, list[TokenizedStringEntry]]
_K = TypeVar('_K')
_V = TypeVar('_V')
_T = TypeVar('_T')


class _TokenDatabaseView(Mapping[_K, _V]):  # pylint: disable=abstract-method
    """Read-only mapping view of a token database.

    Behaves like a read-only version of defaultdict(list).
    """

    def __init__(self, mapping: Mapping[_K, Any]) -> None:
        self._mapping = mapping

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    @overload
    def get(self, key: _K) -> _V | None:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: _K, default: _T) -> _V | _T:  # pylint: disable=W0222
        ...

    def get(self, key: _K, default: _T | None = None) -> _V | _T | None:
        return self._mapping.get(key, default)

    def __iter__(self) -> Iterator[_K]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return str(self._mapping)

    def __repr__(self) -> str:
        return repr(self._mapping)


class _TokenMapping(_TokenDatabaseView[int, Sequence[TokenizedStringEntry]]):
    def __getitem__(self, token: int) -> Sequence[TokenizedStringEntry]:
        """Returns entries that match the specified token; may be empty."""
        return self._mapping.get(token, ())  # Empty sequence if no match


class _DomainTokenMapping(_TokenDatabaseView[str, _TokenMapping]):
    def __getitem__(self, domain: str) -> _TokenMapping:
        """Returns the token-to-entries mapping for the specified domain."""
        return _TokenMapping(self._mapping.get(domain, {}))  # Empty if no match


def _add_entry(entries: _TokenToEntries, entry: TokenizedStringEntry) -> None:
    bisect.insort(
        entries.setdefault(entry.token, []),
        entry,
        key=TokenizedStringEntry.key,  # Keep lists of entries sorted by key.
    )


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (domain, token, string) entry.
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # Index by token and by domain.
        self._token_entries: _TokenToEntries = {}
        self._domain_token_entries: dict[str, _TokenToEntries] = {}

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            TokenizedStringEntry(tokenize(string), string, domain)
            for string in strings
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a Database from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

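    # Example (illustrative): building a database from plain strings; each
    # string is tokenized with pw_tokenizer_65599_hash unless a different
    # callable is supplied.
    #
    #     db = Database.from_strings(['%d items', 'Sensor %s failed'])
    #     assert len(Database.merged(db, Database())) == len(db)
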
    @property
    def token_to_entries(self) -> Mapping[int, Sequence[TokenizedStringEntry]]:
        """Returns a mapping of tokens to a sequence of TokenizedStringEntry.

        Returns token database entries from all domains.
        """
        return _TokenMapping(self._token_entries)

    @property
    def domains(
        self,
    ) -> Mapping[str, Mapping[int, Sequence[TokenizedStringEntry]]]:
        """Returns a mapping of domains to tokens to a sequence of entries.

        `database.domains[domain][token]` returns a sequence of entries
        matching the token in the domain, or an empty sequence if there are
        no matches.
        """
        return _DomainTokenMapping(self._domain_token_entries)

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns an iterable of all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(
        self,
    ) -> Iterator[tuple[int, Sequence[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

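    # Example (illustrative): lookups never raise; unknown tokens yield an
    # empty sequence.
    #
    #     db = Database.from_strings(['hello'])
    #     (entry,) = db.token_to_entries[pw_tokenizer_65599_hash('hello')]
    #     assert entry.string == 'hello'
    #     assert db.token_to_entries[0xFFFFFFFF] == ()
    #     assert not list(db.collisions())  # One entry cannot collide.
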
    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries missing from the database are
        NOT added; call the add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                self._add_new_entry(new_entry)

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            self._delete_entry(entry)

        return to_delete

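    # Example (illustrative): a typical update cycle. New strings are added,
    # entries no longer present are dated, and dated entries are purged.
    #
    #     db = Database.from_strings(['old', 'kept'])
    #     db.mark_removed(Database.from_strings(['kept']).entries())
    #     db.purge()  # Drops every entry with a removal date, i.e. 'old'.
    #     assert len(db) == 1
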
    def merge(self, *databases: Database) -> None:
        """Merges one or more databases into this one, keeping newest dates."""
        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._add_new_entry(entry)

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        to_delete: list[TokenizedStringEntry] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                val
                for val in self._database.values()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                val
                for val in self._database.values()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for entry in to_delete:
            self._delete_entry(entry)

        # Do the replacement after removing entries.
        for search, replacement in replace:
            search = re.compile(search)

            to_replace: list[TokenizedStringEntry] = []
            add: list[TokenizedStringEntry] = []

            for entry in self._database.values():
                new_string = search.sub(replacement, entry.string)
                if new_string != entry.string:
                    to_replace.append(entry)
                    add.append(
                        TokenizedStringEntry(
                            entry.token,
                            new_string,
                            entry.domain,
                            entry.date_removed,
                        )
                    )

            for entry in to_replace:
                self._delete_entry(entry)
            self.add(add)

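    # Example (illustrative): keep only warning strings and scrub a term.
    #
    #     db = Database.from_strings(['WRN low battery', 'DBG tick %d'])
    #     db.filter(include=[r'^WRN'], replace=[(r'battery', 'cell')])
    #     assert [e.string for e in db.entries()] == ['WRN low cell']
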
    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def _add_new_entry(self, new_entry: TokenizedStringEntry) -> None:
        entry = TokenizedStringEntry(  # These are mutable, so make a copy.
            new_entry.token,
            new_entry.string,
            new_entry.domain,
            new_entry.date_removed,
        )
        self._database[entry.key()] = entry
        _add_entry(self._token_entries, entry)
        _add_entry(
            self._domain_token_entries.setdefault(entry.domain, {}), entry
        )

    def _delete_entry(self, entry: TokenizedStringEntry) -> None:
        del self._database[entry.key()]

        # Remove from the token / domain mappings and clean up empty lists.
        self._token_entries[entry.token].remove(entry)
        if not self._token_entries[entry.token]:
            del self._token_entries[entry.token]

        self._domain_token_entries[entry.domain][entry.token].remove(entry)
        if not self._domain_token_entries[entry.domain][entry.token]:
            del self._domain_token_entries[entry.domain][entry.token]
            if not self._domain_token_entries[entry.domain]:
                del self._domain_token_entries[entry.domain]

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()

def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    for line in csv.reader(fd):
        try:
            try:
                token_str, date_str, domain, string_literal = line
            except ValueError:
                # If there are only three columns, use the default domain.
                token_str, date_str, string_literal = line
                domain = DEFAULT_DOMAIN

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            yield TokenizedStringEntry(token, string_literal, domain, date)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )

def write_csv(database: Database, fd: IO[bytes]) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: IO[bytes], entry: TokenizedStringEntry) -> None:
    """Writes a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}","{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.domain.replace('"', '""'),  # escape " as ""
            entry.string.replace('"', '""'),
        ).encode()
    )

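# Example (illustrative tokens): each CSV line is the token as eight hex
# digits, the removal date padded to ten characters (blank if the entry is
# still present), the domain, and the string, with " escaped as "":
#
#     00a1b2c3,          ,"","Example string with ""quotes"""
#     00d4e5f6,2020-01-02,"app","A removed string"

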
class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()

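# Per the structs above, the binary format is a 16-byte header (8-byte magic,
# little-endian uint32 entry count, 4 padding bytes) followed by one 8-byte
# record per entry (uint32 token, uint8 day, uint8 month, uint16 year) and,
# finally, the null-terminated strings in the same order as the records.

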
class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err

def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)

def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, store the maximum values (0xFF for
            # the day and month, 0xFFFF for the year; 0xFFFFFFFF when packed).
            # That ensures that still-present tokens appear as the newest
            # tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)

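# Example (illustrative): the binary format round-trips through memory.
#
#     buffer = io.BytesIO()
#     write_binary(Database.from_strings(['hi']), buffer)
#     buffer.seek(0)
#     assert [e.string for e in parse_binary(buffer)] == ['hi']

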
class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes the database to
    the file from which it was created, in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile of the type matching the file's format."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries, then exports.

        Removes temporary entries from the database before adding the given
        entries, so that only recurring entries are kept in memory and
        written to disk in the original format.
        """

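# Example (illustrative, assuming 'tokens.csv' is an existing CSV database):
#
#     db = DatabaseFile.load(Path('tokens.csv'))
#     db.add(Database.from_strings(['new string']).entries())
#     db.write_to_file()  # Rewrites tokens.csv in CSV format.

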
class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )

# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)

class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file.
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list[str]) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit, which indicates
        # whether the top commit has been merged. If it has been merged,
        # create a new CSV to use. Otherwise, check if a CSV was added in the
        # top commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename that does not exist in the directory."""
        # Regenerate the name until it collides with neither tracked nor
        # untracked files on disk.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV, but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())