# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import bisect
from collections.abc import (
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Sequence,
    ValuesView,
)
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    Any,
    BinaryIO,
    IO,
    NamedTuple,
    overload,
    Pattern,
    TextIO,
    TypeVar,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
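
    For example, a token can be computed from a format string in Python (the
    string below is illustrative):

        token = pw_tokenizer_65599_hash('The answer is: %s')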
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)


@dataclass(frozen=True, eq=True, order=True)
class _EntryKey:
    """Uniquely refers to an entry."""

    domain: str
    token: int
    string: str


class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    def __init__(
        self,
        token: int,
        string: str,
        domain: str = DEFAULT_DOMAIN,
        date_removed: datetime | None = None,
    ) -> None:
        self._key = _EntryKey(
            ''.join(domain.split()),
            token,
            string,
        )
        self.date_removed = date_removed

    @property
    def token(self) -> int:
        return self._key.token

    @property
    def string(self) -> str:
        return self._key.string

    @property
    def domain(self) -> str:
        return self._key.domain

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return self._key

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __eq__(self, other: Any) -> bool:
        return (
            self.key() == other.key()
            and self.date_removed == other.date_removed
        )

    def __lt__(self, other: Any) -> bool:
        """Sorts the entry by domain, token, date removed, then string."""
        if self.domain != other.domain:
            return self.domain < other.domain

        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or
        # still present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(token=0x{self.token:08x}, '
            f'string={self.string!r}, domain={self.domain!r})'
        )


_TokenToEntries = dict[int, list[TokenizedStringEntry]]
_K = TypeVar('_K')
_V = TypeVar('_V')
_T = TypeVar('_T')


class _TokenDatabaseView(Mapping[_K, _V]):  # pylint: disable=abstract-method
    """Read-only mapping view of a token database.

    Behaves like a read-only version of defaultdict(list).
    """

    def __init__(self, mapping: Mapping[_K, Any]) -> None:
        self._mapping = mapping

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    @overload
    def get(self, key: _K) -> _V | None:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: _K, default: _T) -> _V | _T:  # pylint: disable=W0222
        ...

    def get(self, key: _K, default: _T | None = None) -> _V | _T | None:
        return self._mapping.get(key, default)

    def __iter__(self) -> Iterator[_K]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return str(self._mapping)

    def __repr__(self) -> str:
        return repr(self._mapping)


class _TokenMapping(_TokenDatabaseView[int, Sequence[TokenizedStringEntry]]):
    def __getitem__(self, token: int) -> Sequence[TokenizedStringEntry]:
        """Returns strings that match the specified token; may be empty."""
        return self._mapping.get(token, ())  # Empty sequence if no match


class _DomainTokenMapping(_TokenDatabaseView[str, _TokenMapping]):
    def __getitem__(self, domain: str) -> _TokenMapping:
        """Returns the token-to-strings mapping for the specified domain."""
        return _TokenMapping(self._mapping.get(domain, {}))  # Empty if no match


def _add_entry(entries: _TokenToEntries, entry: TokenizedStringEntry) -> None:
    bisect.insort(
        entries.setdefault(entry.token, []),
        entry,
        key=TokenizedStringEntry.key,  # Keep lists of entries sorted by key.
    )


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # Index by token and domain
        self._token_entries: _TokenToEntries = {}
        self._domain_token_entries: dict[str, _TokenToEntries] = {}

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            TokenizedStringEntry(tokenize(string), string, domain)
            for string in strings
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a Database from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Mapping[int, Sequence[TokenizedStringEntry]]:
        """Returns a mapping of tokens to a sequence of TokenizedStringEntry.

        Returns token database entries from all domains.
        """
        return _TokenMapping(self._token_entries)

    @property
    def domains(
        self,
    ) -> Mapping[str, Mapping[int, Sequence[TokenizedStringEntry]]]:
        """Returns a mapping of domains to tokens to a sequence of entries.

        `database.domains[domain][token]` returns a sequence of strings
        matching the token in the domain, or an empty sequence if there are no
        matches.
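
        For example (illustrative domain and token values):

            entries = database.domains['my_domain'][0x2e668cd6]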
        """
        return _DomainTokenMapping(self._domain_token_entries)

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns an iterable of all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(
        self,
    ) -> Iterator[tuple[int, Sequence[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries that are missing from the
        database are NOT added; call the add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]

                # Keep the latest removal date between the two entries.
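                # A missing removal date (None) means the entry is still
                # present, so it is treated as newer than any removal date.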
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                self._add_new_entry(new_entry)

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            self._delete_entry(entry)

        return to_delete

    def merge(self, *databases: Database) -> None:
        """Merges two or more databases together, keeping the newest dates."""
        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._add_new_entry(entry)

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        to_delete: list[TokenizedStringEntry] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                val
                for val in self._database.values()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                val
                for val in self._database.values()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for entry in to_delete:
            self._delete_entry(entry)

        # Do the replacement after removing entries.
        for search, replacement in replace:
            search = re.compile(search)

            to_replace: list[TokenizedStringEntry] = []
            add: list[TokenizedStringEntry] = []

            for entry in self._database.values():
                new_string = search.sub(replacement, entry.string)
                if new_string != entry.string:
                    to_replace.append(entry)
                    add.append(
                        TokenizedStringEntry(
                            entry.token,
                            new_string,
                            entry.domain,
                            entry.date_removed,
                        )
                    )

            for entry in to_replace:
                self._delete_entry(entry)
            self.add(add)

    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def _add_new_entry(self, new_entry: TokenizedStringEntry) -> None:
        entry = TokenizedStringEntry(  # These are mutable, so make a copy.
            new_entry.token,
            new_entry.string,
            new_entry.domain,
            new_entry.date_removed,
        )
        self._database[entry.key()] = entry
        _add_entry(self._token_entries, entry)
        _add_entry(
            self._domain_token_entries.setdefault(entry.domain, {}), entry
        )

    def _delete_entry(self, entry: TokenizedStringEntry) -> None:
        del self._database[entry.key()]

        # Remove from the token / domain mappings and clean up empty lists.
        self._token_entries[entry.token].remove(entry)
        if not self._token_entries[entry.token]:
            del self._token_entries[entry.token]

        self._domain_token_entries[entry.domain][entry.token].remove(entry)
        if not self._domain_token_entries[entry.domain][entry.token]:
            del self._domain_token_entries[entry.domain][entry.token]
        if not self._domain_token_entries[entry.domain]:
            del self._domain_token_entries[entry.domain]

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    for line in csv.reader(fd):
        try:
            try:
                token_str, date_str, domain, string_literal = line
            except ValueError:
                # If there are only three columns, use the default domain.
                token_str, date_str, string_literal = line
                domain = DEFAULT_DOMAIN

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            yield TokenizedStringEntry(token, string_literal, domain, date)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )


def write_csv(database: Database, fd: IO[bytes]) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: IO[bytes], entry: TokenizedStringEntry):
    """Writes a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
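    # For example, with illustrative token values and strings, lines are
    # written as:
    #
    #   0000a5b6,2023-05-01,"","Hello, %s!"
    #   00c5a1ef,          ,"my_domain","Value: %d"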
    fd.write(
        '{:08x},{:10},"{}","{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.domain.replace('"', '""'),  # escape " as ""
            entry.string.replace('"', '""'),
        ).encode()
    )


class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up
    # strings.
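    # The string table is the concatenation of the entries' strings, each
    # null-terminated, stored in the same order as the entry records above.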
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff
            # for the day/month/year. That ensures that still-present tokens
            # appear as the newest tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which the database was created in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile that corresponds to the file's format."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards temporary entries and adds new entries for export.

        Removes temporary entries from the database before adding the provided
        entries, so that only recurring entries are kept in memory and written
        to disk in the original format.
        """


class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from tokenizer CSVs in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file.
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if
          any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create
        # a new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename that is not already in the directory."""
        # Ensure the generated name does not collide with any existing file.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV, but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())