#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests decoding a proto with tokenized fields."""

import base64
import unittest

from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage

from pw_tokenizer import detokenize, encode, tokens
from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized

# Token database shared by the detokenize_fields() tests below. The
# 0x12345678 entry expands to text that itself contains a Base64 token
# ($oeQAAA== -> 0x0000E4A1), which exercises recursive detokenization.
_DATABASE = tokens.Database(
    [
        tokens.TokenizedStringEntry(0xAABBCCDD, "Luke, we're gonna have %s"),
        tokens.TokenizedStringEntry(0x12345678, "This string has a $oeQAAA=="),
        tokens.TokenizedStringEntry(0x0000E4A1, "recursive token"),
    ]
)
_DETOKENIZER = detokenize.Detokenizer(_DATABASE)


class TestDetokenizeProtoFields(unittest.TestCase):
    """Tests detokenizing optionally tokenized proto fields."""

    def _check(self, data: bytes, expected: bytes) -> None:
        # Detokenizes the message field in place and verifies the result.
        proto = TheMessage(message=data)
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(expected, proto.message)

    def test_plain_text(self) -> None:
        # Non-token data passes through untouched.
        self._check(
            b'boring conversation anyway!', b'boring conversation anyway!'
        )

    def test_binary(self) -> None:
        # Little-endian token 0xAABBCCDD followed by a length-prefixed string
        # argument for the %s.
        self._check(
            b'\xDD\xCC\xBB\xAA\x07company', b"Luke, we're gonna have company"
        )

    def test_binary_missing_arguments(self) -> None:
        # Without arguments, the %s specifier is left in the output.
        self._check(b'\xDD\xCC\xBB\xAA', b"Luke, we're gonna have %s")

    def test_recursive_binary(self) -> None:
        # The expansion contains a Base64 token that is itself detokenized.
        self._check(b'\x78\x56\x34\x12', b"This string has a recursive token")

    def test_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        self._check(encoded.encode(), b"Luke, we're gonna have company")

    def test_recursive_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\x78\x56\x34\x12')
        self._check(encoded.encode(), b"This string has a recursive token")

    def test_plain_text_with_prefixed_base64(self) -> None:
        # A Base64 token embedded in plain text is replaced in place.
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        self._check(
            b'Good morning, ' + encoded.encode(),
            b"Good morning, Luke, we're gonna have pancakes!",
        )

    def test_unknown_token_not_utf8(self) -> None:
        # Unrecognized, non-UTF-8 data is re-encoded as prefixed Base64.
        proto = TheMessage(message=b'\xFE\xED\xF0\x0D')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            encode.prefixed_base64(b'\xFE\xED\xF0\x0D'), proto.message.decode()
        )

    def test_only_control_characters(self) -> None:
        # Valid UTF-8 but unprintable data is also re-encoded as Base64.
        proto = TheMessage(message=b'\1\2\3\4')
        detokenize_fields(_DETOKENIZER, proto)
        self.assertEqual(
            encode.prefixed_base64(b'\1\2\3\4'), proto.message.decode()
        )

    def test_no_detokenizer_plain_text(self) -> None:
        # With no detokenizer, printable text still passes through.
        proto = TheMessage(message=b'boring conversation anyway!')
        detokenize_fields(None, proto)
        self.assertEqual(b'boring conversation anyway!', proto.message)

    def test_no_detokenizer_unknown_token_not_utf8(self) -> None:
        proto = TheMessage(message=b'\xFE\xED\xF0\x0D')
        detokenize_fields(None, proto)
        self.assertEqual(
            encode.prefixed_base64(b'\xFE\xED\xF0\x0D'), proto.message.decode()
        )

    def test_no_detokenizer_only_control_characters(self) -> None:
        proto = TheMessage(message=b'\1\2\3\4')
        detokenize_fields(None, proto)
        self.assertEqual(
            encode.prefixed_base64(b'\1\2\3\4'), proto.message.decode()
        )


class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests optional detokenization directly."""

    def setUp(self) -> None:
        # Tokens 0-3 map to the phrases below, in order.
        database = tokens.Database(
            [
                tokens.TokenizedStringEntry(token, text)
                for token, text in enumerate(
                    ('cheese', 'on pizza', 'is quite good', 'they say')
                )
            ]
        )
        self.detok = detokenize.Detokenizer(database)
        # Same database, but recognizing '~' instead of '$' as the
        # Base64 token prefix.
        self.detok_tilde_prefix = detokenize.Detokenizer(database, prefix='~')

    def test_found_binary_token(self) -> None:
        result = decode_optionally_tokenized(self.detok, b'\x01\x00\x00\x00')
        self.assertEqual(result, 'on pizza')

    def test_missing_binary_token(self) -> None:
        # Unknown binary data comes back as prefixed Base64.
        raw = b'\xD5\x8A\xF9\x2A\x8A'
        expected = '$' + base64.b64encode(raw).decode()
        self.assertEqual(expected, decode_optionally_tokenized(self.detok, raw))

    def test_found_b64_token(self) -> None:
        payload = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, payload), 'they say'
        )

    def test_missing_b64_token(self) -> None:
        # A Base64 token that is not in the database is left as-is.
        payload = b'$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, payload),
            payload.decode(),
        )

    def test_found_alternate_prefix(self) -> None:
        payload = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok_tilde_prefix, payload),
            'cheese',
        )

    def test_missing_alternate_prefix(self) -> None:
        # A '~'-prefixed token is not recognized by the '$'-prefix
        # detokenizer, so it passes through unchanged.
        payload = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(
            decode_optionally_tokenized(self.detok, payload),
            payload.decode(),
        )

    def test_no_detokenizer_binary(self) -> None:
        # Without a detokenizer, binary data becomes prefixed Base64.
        data = b'\x01\x00\x00\x00'
        self.assertEqual(
            decode_optionally_tokenized(None, data),
            encode.prefixed_base64(data),
        )

    def test_no_detokenizer_printable_utf8(self) -> None:
        # Printable UTF-8 (including whitespace controls) is decoded directly.
        text = b'this\tis\r\nsome\ntext'
        self.assertEqual(
            decode_optionally_tokenized(None, text), text.decode()
        )

    def test_no_detokenizer_nonprintable_utf8(self) -> None:
        # Non-printable bytes force Base64 encoding even without a detokenizer.
        data = b'\a\0wh\nat?'
        self.assertEqual(
            decode_optionally_tokenized(None, data),
            encode.prefixed_base64(data),
        )


if __name__ == '__main__':
    unittest.main()