#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the tokens module."""

from datetime import datetime
import io
import logging
from pathlib import Path
import shutil
import tempfile
from typing import Iterator
import unittest

from pw_tokenizer import tokens
from pw_tokenizer.tokens import c_hash, DIR_DB_SUFFIX, _LOG

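# Each CSV row holds an eight-digit hexadecimal token, an optional removal
# date, a quoted domain, and the quoted string itself.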
CSV_DATABASE = '''\
00000000,2019-06-10,"",""
141c35d5,          ,"","The answer: ""%s"""
2db1515f,          ,"","%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c"
2e668cd6,2019-06-11,"","Jello, world!"
31631781,          ,"","%d"
61fd1e26,          ,"","%ld"
68ab92da,          ,"","%s there are %x (%.2f) of them%c"
7b940e2a,          ,"","Hello %s! %hd %e"
851beeb6,          ,"","%u %d"
881436a0,          ,"","The answer is: %s"
ad002c97,          ,"","%llx"
b3653e13,2019-06-12,"","Jello!"
b912567b,          ,"","%x%lld%1.2f%s"
cc6d3131,2020-01-01,"","Jello?"
e13b0f94,          ,"","%llu"
e65aefef,2019-06-10,"","Won't fit : %s%d"
'''

# The date 2019-06-10 is 07E3-06-0A in hex. In database order, it's 0A 06 E3 07.
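# Each eight-byte entry below is a little-endian uint32 token followed by that
# packed date (day, month, uint16 year); 0xFFFFFFFF marks entries with no
# removal date. The null-terminated strings follow, one per entry, in entry
# order. A sketch of the date packing (illustration only, not a tokens API):
#
#   import struct
#   assert struct.pack('<BBH', 10, 6, 2019) == b'\x0a\x06\xe3\x07'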
BINARY_DATABASE = (
    b'TOKENS\x00\x00\x10\x00\x00\x00\0\0\0\0'  # header (0x10 entries)
    b'\x00\x00\x00\x00\x0a\x06\xe3\x07'  # 0x01
    b'\xd5\x35\x1c\x14\xff\xff\xff\xff'  # 0x02
    b'\x5f\x51\xb1\x2d\xff\xff\xff\xff'  # 0x03
    b'\xd6\x8c\x66\x2e\x0b\x06\xe3\x07'  # 0x04
    b'\x81\x17\x63\x31\xff\xff\xff\xff'  # 0x05
    b'\x26\x1e\xfd\x61\xff\xff\xff\xff'  # 0x06
    b'\xda\x92\xab\x68\xff\xff\xff\xff'  # 0x07
    b'\x2a\x0e\x94\x7b\xff\xff\xff\xff'  # 0x08
    b'\xb6\xee\x1b\x85\xff\xff\xff\xff'  # 0x09
    b'\xa0\x36\x14\x88\xff\xff\xff\xff'  # 0x0a
    b'\x97\x2c\x00\xad\xff\xff\xff\xff'  # 0x0b
    b'\x13\x3e\x65\xb3\x0c\x06\xe3\x07'  # 0x0c
    b'\x7b\x56\x12\xb9\xff\xff\xff\xff'  # 0x0d
    b'\x31\x31\x6d\xcc\x01\x01\xe4\x07'  # 0x0e
    b'\x94\x0f\x3b\xe1\xff\xff\xff\xff'  # 0x0f
    b'\xef\xef\x5a\xe6\x0a\x06\xe3\x07'  # 0x10
    b'\x00'
    b'The answer: "%s"\x00'
    b'%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c\x00'
    b'Jello, world!\x00'
    b'%d\x00'
    b'%ld\x00'
    b'%s there are %x (%.2f) of them%c\x00'
    b'Hello %s! %hd %e\x00'
    b'%u %d\x00'
    b'The answer is: %s\x00'
    b'%llx\x00'
    b'Jello!\x00'
    b'%x%lld%1.2f%s\x00'
    b'Jello?\x00'
    b'%llu\x00'
    b'Won\'t fit : %s%d\x00'
)

INVALID_CSV = """\
1,,"Whoa there!"
2,this is totally invalid,"Whoa there!"
3,,"This one's OK"
,,"Also broken"
5,1845-02-02,"I'm %s fine"
6,"Missing fields"
"""

CSV_DATABASE_2 = '''\
00000000,          ,"",""
141c35d5,          ,"","The answer: ""%s"""
29aef586,          ,"","1234"
2b78825f,          ,"","[:-)"
2e668cd6,          ,"","Jello, world!"
31631781,          ,"","%d"
61fd1e26,          ,"","%ld"
68ab92da,          ,"","%s there are %x (%.2f) of them%c"
7b940e2a,          ,"","Hello %s! %hd %e"
7da55d52,          ,"",">:-[]"
7f35a9a5,          ,"","TestName"
851beeb6,          ,"","%u %d"
881436a0,          ,"","The answer is: %s"
88808930,          ,"","%u%d%02x%X%hu%hhd%d%ld%lu%lld%llu%c%c%c"
92723f44,          ,"","???"
a09d6698,          ,"","won-won-won-wonderful"
aa9ffa66,          ,"","void pw::tokenizer::{anonymous}::TestName()"
ad002c97,          ,"","%llx"
b3653e13,          ,"","Jello!"
cc6d3131,          ,"","Jello?"
e13b0f94,          ,"","%llu"
e65aefef,          ,"","Won't fit : %s%d"
'''

CSV_DATABASE_3 = """\
17fa86d3,          ,"TEST_DOMAIN","hello"
18c5017c,          ,"TEST_DOMAIN","yes"
59b2701c,          ,"TEST_DOMAIN","The answer was: %s"
881436a0,          ,"TEST_DOMAIN","The answer is: %s"
d18ada0f,          ,"TEST_DOMAIN","something"
"""

CSV_DATABASE_4 = '''\
00000000,          ,"",""
141c35d5,          ,"","The answer: ""%s"""
29aef586,          ,"","1234"
2b78825f,          ,"","[:-)"
2e668cd6,          ,"","Jello, world!"
31631781,          ,"","%d"
61fd1e26,          ,"","%ld"
68ab92da,          ,"","%s there are %x (%.2f) of them%c"
7b940e2a,          ,"","Hello %s! %hd %e"
7da55d52,          ,"",">:-[]"
7f35a9a5,          ,"","TestName"
851beeb6,          ,"","%u %d"
881436a0,          ,"","The answer is: %s"
88808930,          ,"","%u%d%02x%X%hu%hhd%d%ld%lu%lld%llu%c%c%c"
92723f44,          ,"","???"
a09d6698,          ,"","won-won-won-wonderful"
aa9ffa66,          ,"","void pw::tokenizer::{anonymous}::TestName()"
ad002c97,          ,"","%llx"
b3653e13,          ,"","Jello!"
cc6d3131,          ,"","Jello?"
e13b0f94,          ,"","%llu"
e65aefef,          ,"","Won't fit : %s%d"
17fa86d3,          ,"TEST_DOMAIN","hello"
18c5017c,          ,"TEST_DOMAIN","yes"
59b2701c,          ,"TEST_DOMAIN","The answer was: %s"
881436a0,          ,"TEST_DOMAIN","The answer is: %s"
d18ada0f,          ,"TEST_DOMAIN","something"
'''

CSV_DATABASE_5 = """\
00000001,1998-09-04,"Domain","hello"
00000002,          ,"","yes"
00000002,          ,"Domain","No!"
00000004,          ,"?","The answer is: %s"
"""

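# Legacy three-column databases have no domain column; their entries load with
# the default domain '' (see test_csv_legacy_no_domain_database).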
CSV_DATABASE_5_NO_DOMAIN = """\
00000001,1998-09-04,"hello"
00000002,          ,"yes"
00000002,          ,"No!"
00000004,          ,"The answer is: %s"
"""

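# All whitespace is removed from domain names on load (see
# test_csv_remove_domain_whitespace).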
CSV_DATABASE_6_DOMAIN_WHITESPACE = """\
00000001,2001-09-04,"Domain 1","hello"
00000002,          ,"\t","yes"
00000002,          ,"  Domain\t20\n","No!"
00000004,          ,"  ?   ","The answer is: %s"
"""


def read_db_from_csv(csv_str: str) -> tokens.Database:
    with io.StringIO(csv_str) as csv_db:
        return tokens.Database(tokens.parse_csv(csv_db))


def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]:
    for string in strings:
        yield tokens.TokenizedStringEntry(c_hash(string), string)


class TokenDatabaseTest(unittest.TestCase):
    """Tests the token database class."""

    def test_csv(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(str(db), CSV_DATABASE)

        db = read_db_from_csv(CSV_DATABASE_4)
        self.assertEqual(str(db), CSV_DATABASE_4)

        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

    def test_csv_loads_domains(self) -> None:
        db = read_db_from_csv(CSV_DATABASE_5)
        self.assertEqual(
            db.token_to_entries[1],
            [
                tokens.TokenizedStringEntry(
                    token=1,
                    string='hello',
                    domain='Domain',
                    date_removed=datetime(1998, 9, 4),
                )
            ],
        )
        self.assertEqual(
            db.token_to_entries[2],
            [
                tokens.TokenizedStringEntry(token=2, string='yes', domain=''),
                tokens.TokenizedStringEntry(
                    token=2, string='No!', domain='Domain'
                ),
            ],
        )
        self.assertEqual(
            db.token_to_entries[4],
            [
                tokens.TokenizedStringEntry(
                    token=4, string='The answer is: %s', domain='?'
                ),
            ],
        )

    def test_csv_legacy_no_domain_database(self) -> None:
        db = read_db_from_csv(CSV_DATABASE_5_NO_DOMAIN)
        self.assertEqual(
            db.token_to_entries[1],
            [
                tokens.TokenizedStringEntry(
                    token=1,
                    string='hello',
                    domain='',
                    date_removed=datetime(1998, 9, 4),
                )
            ],
        )
        self.assertEqual(
            db.token_to_entries[2],
            [
                tokens.TokenizedStringEntry(token=2, string='No!', domain=''),
                tokens.TokenizedStringEntry(token=2, string='yes', domain=''),
            ],
        )
        self.assertEqual(
            db.token_to_entries[4],
            [
                tokens.TokenizedStringEntry(
                    token=4, string='The answer is: %s', domain=''
                ),
            ],
        )

    def test_csv_formatting(self) -> None:
        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

        db = read_db_from_csv('abc123,2048-04-01,Fake string\n')
        self.assertEqual(str(db), '00abc123,2048-04-01,"","Fake string"\n')

        db = read_db_from_csv(
            '1,1990-01-01,"","Quotes"""\n' '0,1990-02-01,"Commas,"",,"\n'
        )
        self.assertEqual(
            str(db),
            (
                '00000000,1990-02-01,"","Commas,"",,"\n'
                '00000001,1990-01-01,"","Quotes"""\n'
            ),
        )

    def test_bad_csv(self) -> None:
        with self.assertLogs(_LOG, logging.ERROR) as logs:
            db = read_db_from_csv(INVALID_CSV)

        self.assertGreaterEqual(len(logs.output), 3)
        self.assertEqual(len(db.token_to_entries), 3)

        self.assertEqual(db.token_to_entries[1][0].string, 'Whoa there!')
        self.assertFalse(db.token_to_entries[2])
        self.assertNotIn(2, db.token_to_entries)
        self.assertEqual(db.token_to_entries[3][0].string, "This one's OK")
        self.assertFalse(db.token_to_entries[4])
        self.assertNotIn(4, db.token_to_entries)
        self.assertEqual(db.token_to_entries[5][0].string, "I'm %s fine")
        self.assertFalse(db.token_to_entries[6])
        self.assertNotIn(6, db.token_to_entries)

    def test_lookup(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)
        self.assertSequenceEqual(db.token_to_entries[0x9999], [])
        self.assertNotIn(0x9999, db.token_to_entries)
        self.assertIsNone(db.token_to_entries.get(0x9999))

        matches = db.token_to_entries[0x2E668CD6]
        self.assertEqual(len(matches), 1)
        jello = matches[0]

        self.assertEqual(jello.token, 0x2E668CD6)
        self.assertEqual(jello.string, 'Jello, world!')
        self.assertEqual(jello.date_removed, datetime(2019, 6, 11))

        matches = db.token_to_entries[0xE13B0F94]
        self.assertEqual(len(matches), 1)
        llu = matches[0]
        self.assertEqual(llu.token, 0xE13B0F94)
        self.assertEqual(llu.string, '%llu')
        self.assertIsNone(llu.date_removed)

        (answer,) = db.token_to_entries[0x141C35D5]
        self.assertEqual(answer.string, 'The answer: "%s"')

    def test_domains(self) -> None:
        """Tests the domains mapping."""
        db = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'one', 'D1'),
                tokens.TokenizedStringEntry(2, 'two', 'D1'),
                tokens.TokenizedStringEntry(1, 'one', 'D2'),
                tokens.TokenizedStringEntry(1, 'one!', 'D3', datetime.min),
                tokens.TokenizedStringEntry(3, 'zzz', 'D1'),
                tokens.TokenizedStringEntry(3, 'three', 'D1', datetime.min),
                tokens.TokenizedStringEntry(3, 'three', 'D1'),
                tokens.TokenizedStringEntry(3, 'zzzz', 'D1'),
            ]
        )
        self.assertEqual(db.domains.keys(), {'D1', 'D2', 'D3'})
        self.assertEqual(
            db.domains['D1'],
            {
                1: [tokens.TokenizedStringEntry(1, 'one', 'D1')],
                2: [tokens.TokenizedStringEntry(2, 'two', 'D1')],
                3: [
                    tokens.TokenizedStringEntry(3, 'three', 'D1'),
                    tokens.TokenizedStringEntry(3, 'zzz', 'D1'),
                    tokens.TokenizedStringEntry(3, 'zzzz', 'D1'),
                ],
            },
        )
        self.assertEqual(
            db.domains['D2'],
            {
                1: [tokens.TokenizedStringEntry(1, 'one', 'D2')],
            },
        )
        self.assertEqual(
            db.domains['D3'],
            {
                1: [tokens.TokenizedStringEntry(1, 'one!', 'D3', datetime.min)],
            },
        )
        self.assertEqual(db.domains['not a domain!'], {})
        self.assertNotIn('not a domain!', db.domains)
        self.assertIsNone(db.domains.get('not a domain'))

    def test_collisions(self) -> None:
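        # 'o000' and '0Q1Q' collide when hashed over their first 96
        # characters, as the assertion below verifies.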
        hash_1 = tokens.c_hash('o000', 96)
        hash_2 = tokens.c_hash('0Q1Q', 96)
        self.assertEqual(hash_1, hash_2)

        db = tokens.Database.from_strings(['o000', '0Q1Q'])

        self.assertEqual(len(db.token_to_entries[hash_1]), 2)
        self.assertCountEqual(
            [entry.string for entry in db.token_to_entries[hash_1]],
            ['o000', '0Q1Q'],
        )

    def test_purge(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)
        original_length = len(db.token_to_entries)

        self.assertEqual(db.token_to_entries[0][0].string, '')
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertEqual(
            db.token_to_entries[0x2E668CD6][0].string, 'Jello, world!'
        )
        self.assertEqual(db.token_to_entries[0xB3653E13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xCC6D3131][0].string, 'Jello?')
        self.assertEqual(
            db.token_to_entries[0xE65AEFEF][0].string, "Won't fit : %s%d"
        )

        db.purge(datetime(2019, 6, 11))
        self.assertLess(len(db.token_to_entries), original_length)
        self.assertEqual(len(db.token_to_entries), len(db.entries()))

        self.assertFalse(db.token_to_entries[0])
        self.assertNotIn(0, db.token_to_entries)
        self.assertSequenceEqual(db.token_to_entries[0], [])
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertFalse(db.token_to_entries[0x2E668CD6])
        self.assertNotIn(0x2E668CD6, db.token_to_entries)
        self.assertEqual(db.token_to_entries[0xB3653E13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xCC6D3131][0].string, 'Jello?')
        self.assertFalse(db.token_to_entries[0xE65AEFEF])
        self.assertNotIn(0xE65AEFEF, db.token_to_entries)

    def test_merge(self) -> None:
        """Tests the tokens.Database merge method."""

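        # When the same entry is merged from several databases, the newest
        # date_removed wins, and None (still in use) beats any date.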
        db = tokens.Database()

        # Test basic merging into an empty database.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        1, 'one', date_removed=datetime.min
                    ),
                    tokens.TokenizedStringEntry(
                        2, 'two', 'domain', date_removed=datetime.min
                    ),
                ]
            )
        )
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
        self.assertEqual(db.token_to_entries[1][0].date_removed, datetime.min)
        self.assertEqual(db.token_to_entries[2][0].date_removed, datetime.min)

        # Test merging in an entry with a removal date.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(3, 'three'),
                    tokens.TokenizedStringEntry(
                        4, 'four', date_removed=datetime.min
                    ),
                ]
            )
        )
        self.assertEqual(
            {str(e) for e in db.entries()}, {'one', 'two', 'three', 'four'}
        )
        self.assertIsNone(db.token_to_entries[3][0].date_removed)
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.min)

        # Test merging in one entry.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(5, 'five'),
                ]
            )
        )
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.min)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in repeated entries with different removal dates.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        4, 'four', date_removed=datetime.max
                    ),
                    tokens.TokenizedStringEntry(
                        5, 'five', date_removed=datetime.max
                    ),
                ]
            )
        )
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.max)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in the same repeated entries now without removal dates.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(4, 'four'),
                    tokens.TokenizedStringEntry(5, 'five'),
                ]
            )
        )
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertIsNone(db.token_to_entries[4][0].date_removed)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in an empty database.
        db.merge(tokens.Database([]))
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )

    def test_merge_multiple_databases_in_one_call(self) -> None:
507        """Tests the merge and merged methods with multiple databases."""
508        db = tokens.Database.merged(
509            tokens.Database(
510                [
511                    tokens.TokenizedStringEntry(
512                        1, 'one', date_removed=datetime.max
513                    )
514                ]
515            ),
516            tokens.Database(
517                [
518                    tokens.TokenizedStringEntry(
519                        2, 'two', date_removed=datetime.min
520                    )
521                ]
522            ),
523            tokens.Database(
524                [
525                    tokens.TokenizedStringEntry(
526                        1, 'one', date_removed=datetime.min
527                    )
528                ]
529            ),
530        )
531        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
532
533        db.merge(
534            tokens.Database(
535                [
536                    tokens.TokenizedStringEntry(
537                        4, 'four', date_removed=datetime.max
538                    )
539                ]
540            ),
541            tokens.Database(
542                [
543                    tokens.TokenizedStringEntry(
544                        2, 'two', date_removed=datetime.max
545                    )
546                ]
547            ),
548            tokens.Database(
549                [
550                    tokens.TokenizedStringEntry(
551                        3, 'three', date_removed=datetime.min
552                    )
553                ]
554            ),
555        )
556        self.assertEqual(
557            {str(e) for e in db.entries()}, {'one', 'two', 'three', 'four'}
558        )
559
560    def test_merge_same_tokens_different_domains(self) -> None:
561        db = tokens.Database.merged(
562            tokens.Database([tokens.TokenizedStringEntry(2, 'two', 'D1')]),
563            tokens.Database([tokens.TokenizedStringEntry(1, 'one', 'D2')]),
564            tokens.Database([tokens.TokenizedStringEntry(1, 'one', 'D2')]),
565            tokens.Database([tokens.TokenizedStringEntry(1, 'one!', 'D3')]),
566            tokens.Database([tokens.TokenizedStringEntry(1, 'one', 'D1')]),
567        )
568        self.assertEqual(
569            sorted(db.entries()),
570            sorted(
571                [
572                    tokens.TokenizedStringEntry(1, 'one', 'D1'),
573                    tokens.TokenizedStringEntry(2, 'two', 'D1'),
574                    tokens.TokenizedStringEntry(1, 'one', 'D2'),
575                    tokens.TokenizedStringEntry(1, 'one!', 'D3'),
576                ]
577            ),
578        )
579        self.assertEqual(
580            db.token_to_entries[1],
581            [
582                tokens.TokenizedStringEntry(1, 'one', 'D1'),
583                tokens.TokenizedStringEntry(1, 'one', 'D2'),
584                tokens.TokenizedStringEntry(1, 'one!', 'D3'),
585            ],
586        )
587
    def test_entry_counts(self) -> None:
        self.assertEqual(len(CSV_DATABASE.splitlines()), 16)

        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(len(db.entries()), 16)
        self.assertEqual(len(db.token_to_entries), 16)

        # Add two strings with the same hash: the entry count grows by two,
        # but the token count grows by only one.
        db.add(_entries('o000', '0Q1Q'))

        self.assertEqual(len(db.entries()), 18)
        self.assertEqual(len(db.token_to_entries), 17)

    def test_mark_removed(self) -> None:
        """Tests that date_removed field is set by mark_removed."""
        db = tokens.Database.from_strings(
            ['MILK', 'apples', 'oranges', 'CHEESE', 'pears']
        )

        self.assertTrue(
            all(entry.date_removed is None for entry in db.entries())
        )
        date_1 = datetime(1, 2, 3)

        db.mark_removed(_entries('apples', 'oranges', 'pears'), date_1)

        self.assertEqual(
            db.token_to_entries[c_hash('MILK')][0].date_removed, date_1
        )
        self.assertEqual(
            db.token_to_entries[c_hash('CHEESE')][0].date_removed, date_1
        )

        now = datetime.now()
        db.mark_removed(_entries('MILK', 'CHEESE', 'pears'))

        # New strings are not added or re-added in mark_removed().
        milk_date = db.token_to_entries[c_hash('MILK')][0].date_removed
        assert milk_date is not None
        self.assertGreaterEqual(milk_date, date_1)

        cheese_date = db.token_to_entries[c_hash('CHEESE')][0].date_removed
        assert cheese_date is not None
        self.assertGreaterEqual(cheese_date, date_1)

        # These strings were removed.
        apples_date = db.token_to_entries[c_hash('apples')][0].date_removed
        assert apples_date is not None
        self.assertGreaterEqual(apples_date, now)

        oranges_date = db.token_to_entries[c_hash('oranges')][0].date_removed
        assert oranges_date is not None
        self.assertGreaterEqual(oranges_date, now)
        self.assertIsNone(db.token_to_entries[c_hash('pears')][0].date_removed)

    def test_add(self) -> None:
        db = tokens.Database()
        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

        db.add(_entries('MILK'))
        self.assertEqual(
            {e.string for e in db.entries()},
            {
                'MILK',
                'apples',
                'oranges',
                'CHEESE',
                'pears',
                'only this one is new',
            },
        )

    def test_add_duplicate_entries_keeps_none_as_removal_date(self) -> None:
        db = tokens.Database()
        db.add(
            [
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', ''),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.min),
            ]
        )
        self.assertEqual(len(db), 1)
        self.assertIsNone(db.token_to_entries[1][0].date_removed)

    def test_add_duplicate_entries_keeps_newest_removal_date(self) -> None:
        db = tokens.Database()
        db.add(
            [
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.max),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.min),
            ]
        )
        self.assertEqual(len(db), 1)
        self.assertEqual(db.token_to_entries[1][0].date_removed, datetime.max)

    def test_difference(self) -> None:
        first = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'one'),
                tokens.TokenizedStringEntry(2, 'two'),
                tokens.TokenizedStringEntry(3, 'three'),
            ]
        )
        second = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'one'),
                tokens.TokenizedStringEntry(3, 'three'),
                tokens.TokenizedStringEntry(4, 'four'),
            ]
        )
        difference = first.difference(second)
        self.assertEqual({e.string for e in difference.entries()}, {'two'})

    def test_tokens_by_domain(self) -> None:
        db = read_db_from_csv(CSV_DATABASE_2)
        self.assertEqual(db.domains.keys(), {''})
        db = read_db_from_csv(CSV_DATABASE_3)
        self.assertEqual(db.domains.keys(), {'TEST_DOMAIN'})
        db = read_db_from_csv(CSV_DATABASE_4)
        self.assertEqual(db.domains.keys(), {'', 'TEST_DOMAIN'})
        db = read_db_from_csv(CSV_DATABASE_5)
        self.assertEqual(db.domains.keys(), {'', '?', 'Domain'})

    def test_binary_format_write(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)

        with io.BytesIO() as fd:
            tokens.write_binary(db, fd)
            binary_db = fd.getvalue()

        self.assertEqual(BINARY_DATABASE, binary_db)

    def test_binary_format_parse(self) -> None:
        with io.BytesIO(BINARY_DATABASE) as binary_db:
            db = tokens.Database(tokens.parse_binary(binary_db))

        self.assertEqual(str(db), CSV_DATABASE)


class TestDatabaseFile(unittest.TestCase):
    """Tests the DatabaseFile class."""

    def setUp(self) -> None:
        file = tempfile.NamedTemporaryFile(delete=False)
        file.close()
        self._path = Path(file.name)

    def tearDown(self) -> None:
        self._path.unlink()

    def test_update_csv_file(self) -> None:
        self._path.write_text(CSV_DATABASE)
        db = tokens.DatabaseFile.load(self._path)
        self.assertEqual(str(db), CSV_DATABASE)

        db.add([tokens.TokenizedStringEntry(0xFFFFFFFF, 'New entry!', '')])

        db.write_to_file()

        self.assertEqual(
            self._path.read_text(),
            CSV_DATABASE + 'ffffffff,          ,"","New entry!"\n',
        )

    def test_csv_file_too_short_raises_exception(self) -> None:
        self._path.write_text('1234')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)

    def test_csv_invalid_format_raises_exception(self) -> None:
        self._path.write_text('MK34567890')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)

    def test_csv_not_utf8(self) -> None:
        self._path.write_bytes(b'\x80' * 20)

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)


class TestFilter(unittest.TestCase):
    """Tests the filtering functionality."""

    def setUp(self) -> None:
        self.db = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'Luke'),
                tokens.TokenizedStringEntry(2, 'Leia'),
                tokens.TokenizedStringEntry(2, 'Darth Vader'),
                tokens.TokenizedStringEntry(2, 'Emperor Palpatine'),
                tokens.TokenizedStringEntry(3, 'Han'),
                tokens.TokenizedStringEntry(4, 'Chewbacca'),
                tokens.TokenizedStringEntry(5, 'Darth Maul'),
                tokens.TokenizedStringEntry(6, 'Han Solo'),
            ]
        )

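    # filter() keeps entries that match any include pattern (all entries when
    # no include patterns are given), then drops entries that match any
    # exclude pattern.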
    def test_filter_include_single_regex(self) -> None:
        self.db.filter(include=[' '])  # anything with a space
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Emperor Palpatine', 'Darth Maul', 'Han Solo'},
        )

    def test_filter_include_multiple_regexes(self) -> None:
        self.db.filter(include=['Darth', 'cc', '^Han$'])
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Darth Maul', 'Han', 'Chewbacca'},
        )

    def test_filter_include_no_matches(self) -> None:
        self.db.filter(include=['Gandalf'])
        self.assertFalse(self.db.entries())

    def test_filter_exclude_single_regex(self) -> None:
        self.db.filter(exclude=['^[^L]'])
        self.assertEqual(
            set(e.string for e in self.db.entries()), {'Luke', 'Leia'}
        )

    def test_filter_exclude_multiple_regexes(self) -> None:
        self.db.filter(exclude=[' ', 'Han', 'Chewbacca'])
        self.assertEqual(
            set(e.string for e in self.db.entries()), {'Luke', 'Leia'}
        )

    def test_filter_exclude_no_matches(self) -> None:
        self.db.filter(exclude=['.*'])
        self.assertFalse(self.db.entries())

    def test_filter_include_and_exclude(self) -> None:
        self.db.filter(include=[' '], exclude=['Darth', 'Emperor'])
        self.assertEqual(set(e.string for e in self.db.entries()), {'Han Solo'})

    def test_filter_neither_include_nor_exclude(self) -> None:
        self.db.filter()
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {
                'Luke',
                'Leia',
                'Darth Vader',
                'Emperor Palpatine',
                'Han',
                'Chewbacca',
                'Darth Maul',
                'Han Solo',
            },
        )

    def test_csv_remove_domain_whitespace(self) -> None:
        db = read_db_from_csv(CSV_DATABASE_6_DOMAIN_WHITESPACE)
        self.assertEqual(
            db.token_to_entries[1],
            [
                tokens.TokenizedStringEntry(
                    token=1,
                    string='hello',
                    domain='Domain1',
                    date_removed=datetime(2001, 9, 4),
                )
            ],
        )
        self.assertEqual(
            db.token_to_entries[2],
            [
                tokens.TokenizedStringEntry(token=2, string='yes', domain=''),
                tokens.TokenizedStringEntry(
                    token=2, string='No!', domain='Domain20'
                ),
            ],
        )
        self.assertEqual(
            db.token_to_entries[4],
            [
                tokens.TokenizedStringEntry(
                    token=4, string='The answer is: %s', domain='?'
                ),
            ],
        )


class TestDirectoryDatabase(unittest.TestCase):
    """Tests loading a token database from a directory of database files."""

    def setUp(self) -> None:
        self._dir = Path(tempfile.mkdtemp('_pw_tokenizer_test'))
        self._db_dir = self._dir / '_dir_database_test'
        self._db_dir.mkdir(exist_ok=True)
        self._db_csv = self._db_dir / f'first{DIR_DB_SUFFIX}'

    def tearDown(self) -> None:
        shutil.rmtree(self._dir)

    def test_loading_empty_directory(self) -> None:
        self.assertFalse(tokens.DatabaseFile.load(self._db_dir).entries())

    def test_loading_a_single_file(self) -> None:
        self._db_csv.write_text(CSV_DATABASE)
        csv = tokens.DatabaseFile.load(self._db_csv)
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(1, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(csv), str(directory_db))

    def test_loading_multiple_files(self) -> None:
        self._db_csv.write_text(CSV_DATABASE_3)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_4)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(3, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(all_databases_merged), str(directory_db))

    def test_loading_multiple_files_with_removal_dates(self) -> None:
        self._db_csv.write_text(CSV_DATABASE)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_3)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(3, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(all_databases_merged), str(directory_db))

    def test_rewrite(self) -> None:
        self._db_dir.joinpath('junk_file').write_text('should be ignored')

        self._db_csv.write_text(CSV_DATABASE_3)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_4)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )

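        # write_to_file(rewrite=True) collapses every database file in the
        # directory into a single file; unrelated files are left untouched.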
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        directory_db.write_to_file(rewrite=True)

        self.assertEqual(1, len(list(self._db_dir.glob(f'*{DIR_DB_SUFFIX}'))))
        self.assertEqual(
            self._db_dir.joinpath('junk_file').read_text(), 'should be ignored'
        )

        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(str(all_databases_merged), str(directory_db))


if __name__ == '__main__':
    unittest.main()