xref: /aosp_15_r20/external/pigweed/pw_tokenizer/py/detokenize_proto_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests decoding a proto with tokenized fields."""
16
17import base64
18import unittest
19
20from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage
21
22from pw_tokenizer import detokenize, encode, tokens
23from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized
24
# Token database shared by the TestDetokenizeProtoFields tests. Each entry
# maps a 32-bit token to its original format string. Note that the
# 0x12345678 entry itself embeds a $-prefixed Base64 token ("$oeQAAA==",
# which decodes to token 0x0000E4A1), exercising recursive detokenization.
_DATABASE = tokens.Database(
    [
        tokens.TokenizedStringEntry(0xAABBCCDD, "Luke, we're gonna have %s"),
        tokens.TokenizedStringEntry(0x12345678, "This string has a $oeQAAA=="),
        tokens.TokenizedStringEntry(0x0000E4A1, "recursive token"),
    ]
)
# Detokenizer instance shared by the TestDetokenizeProtoFields tests.
_DETOKENIZER = detokenize.Detokenizer(_DATABASE)
33
34
class TestDetokenizeProtoFields(unittest.TestCase):
    """Tests detokenizing optionally tokenized proto fields."""

    def _detokenize_message(
        self, payload: bytes, detokenizer=_DETOKENIZER
    ) -> bytes:
        """Builds a TheMessage, detokenizes its fields in place, returns the field."""
        proto = TheMessage(message=payload)
        detokenize_fields(detokenizer, proto)
        return proto.message

    def test_plain_text(self) -> None:
        # Plain text passes through unchanged.
        self.assertEqual(
            self._detokenize_message(b'boring conversation anyway!'),
            b'boring conversation anyway!',
        )

    def test_binary(self) -> None:
        # Little-endian token 0xAABBCCDD plus an encoded string argument.
        self.assertEqual(
            self._detokenize_message(b'\xDD\xCC\xBB\xAA\x07company'),
            b"Luke, we're gonna have company",
        )

    def test_binary_missing_arguments(self) -> None:
        # With no encoded arguments the raw format string is returned.
        self.assertEqual(
            self._detokenize_message(b'\xDD\xCC\xBB\xAA'),
            b"Luke, we're gonna have %s",
        )

    def test_recursive_binary(self) -> None:
        # The 0x12345678 entry embeds a Base64 token that is expanded too.
        self.assertEqual(
            self._detokenize_message(b'\x78\x56\x34\x12'),
            b"This string has a recursive token",
        )

    def test_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        self.assertEqual(
            self._detokenize_message(encoded.encode()),
            b"Luke, we're gonna have company",
        )

    def test_recursive_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\x78\x56\x34\x12')
        self.assertEqual(
            self._detokenize_message(encoded.encode()),
            b"This string has a recursive token",
        )

    def test_plain_text_with_prefixed_base64(self) -> None:
        # A Base64 token embedded in surrounding plain text is replaced in situ.
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        self.assertEqual(
            self._detokenize_message(f'Good morning, {encoded}'.encode()),
            b"Good morning, Luke, we're gonna have pancakes!",
        )

    def test_unknown_token_not_utf8(self) -> None:
        # Unrecognized data that is not valid UTF-8 is Base64-encoded.
        payload = b'\xFE\xED\xF0\x0D'
        self.assertEqual(
            self._detokenize_message(payload).decode(),
            encode.prefixed_base64(payload),
        )

    def test_only_control_characters(self) -> None:
        # Data consisting only of control characters is Base64-encoded.
        payload = b'\1\2\3\4'
        self.assertEqual(
            self._detokenize_message(payload).decode(),
            encode.prefixed_base64(payload),
        )

    def test_no_detokenizer_plain_text(self) -> None:
        self.assertEqual(
            self._detokenize_message(b'boring conversation anyway!', None),
            b'boring conversation anyway!',
        )

    def test_no_detokenizer_unknown_token_not_utf8(self) -> None:
        payload = b'\xFE\xED\xF0\x0D'
        self.assertEqual(
            self._detokenize_message(payload, None).decode(),
            encode.prefixed_base64(payload),
        )

    def test_no_detokenizer_only_control_characters(self) -> None:
        payload = b'\1\2\3\4'
        self.assertEqual(
            self._detokenize_message(payload, None).decode(),
            encode.prefixed_base64(payload),
        )
110
111
class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests optional detokenization directly."""

    def setUp(self) -> None:
        # Tokens 0 through 3 map to the strings below, in order.
        entries = [
            tokens.TokenizedStringEntry(token, string)
            for token, string in enumerate(
                ('cheese', 'on pizza', 'is quite good', 'they say')
            )
        ]
        database = tokens.Database(entries)
        self.detok = detokenize.Detokenizer(database)
        self.detok_tilde_prefix = detokenize.Detokenizer(database, prefix='~')

    def test_found_binary_token(self) -> None:
        # Little-endian token 1 maps to 'on pizza'.
        result = decode_optionally_tokenized(self.detok, b'\x01\x00\x00\x00')
        self.assertEqual(result, 'on pizza')

    def test_missing_binary_token(self) -> None:
        # Unknown binary data is rendered as $-prefixed Base64.
        raw = b'\xD5\x8A\xF9\x2A\x8A'
        self.assertEqual(
            decode_optionally_tokenized(self.detok, raw),
            '$' + base64.b64encode(raw).decode(),
        )

    def test_found_b64_token(self) -> None:
        encoded = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, encoded), 'they say'
        )

    def test_missing_b64_token(self) -> None:
        # A Base64 message whose token is absent from the database is
        # returned as-is.
        encoded = b'$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, encoded),
            encoded.decode(),
        )

    def test_found_alternate_prefix(self) -> None:
        encoded = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok_tilde_prefix, encoded),
            'cheese',
        )

    def test_missing_alternate_prefix(self) -> None:
        # '~'-prefixed data is not recognized by the default-prefix detokenizer,
        # even though token 2 exists in the database.
        encoded = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, encoded),
            encoded.decode(),
        )

    def test_no_detokenizer_binary(self) -> None:
        # Without a detokenizer, non-printable data is Base64-encoded.
        data = b'\x01\x00\x00\x00'
        self.assertEqual(
            decode_optionally_tokenized(None, data),
            encode.prefixed_base64(data),
        )

    def test_no_detokenizer_printable_utf8(self) -> None:
        # Printable UTF-8 (including tabs and newlines) is decoded directly.
        text = b'this\tis\r\nsome\ntext'
        self.assertEqual(
            decode_optionally_tokenized(None, text), text.decode()
        )

    def test_no_detokenizer_nonprintable_utf8(self) -> None:
        data = b'\a\0wh\nat?'
        self.assertEqual(
            decode_optionally_tokenized(None, data),
            encode.prefixed_base64(data),
        )
185
186
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
189