# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Unit tests for pw_software_update/update_bundle.py."""

from pathlib import Path
import tempfile
import unittest

from pw_software_update import update_bundle
from pw_software_update.tuf_pb2 import SignedRootMetadata, TargetsMetadata


def _touch(root: Path, *names: str):
    """Creates an empty file for each name directly under root.

    Args:
        root: Existing directory in which to create the files.
        *names: File names, relative to root.

    Returns:
        A tuple of the created paths, in the same order as names.
    """
    paths = tuple(root / name for name in names)
    for path in paths:
        path.touch()
    return paths


class TargetsFromDirectoryTest(unittest.TestCase):
    """Test turning a directory into TUF targets."""

    def test_excludes(self):
        """Checks that excludes are excluded."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            _, bar_path, _, qux_path = _touch(
                temp_root, 'foo.bin', 'bar.bin', 'baz.bin', 'qux.exe'
            )

            targets = update_bundle.targets_from_directory(
                temp_root, exclude=(Path('foo.bin'), Path('baz.bin'))
            )

            self.assertNotIn('foo.bin', targets)
            self.assertEqual(bar_path, targets['bar.bin'])
            self.assertNotIn('baz.bin', targets)
            self.assertEqual(qux_path, targets['qux.exe'])

    def test_excludes_and_remapping(self):
        """Checks that remapping works, even in combination with excludes."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            foo_path, bar_path, baz_path, _ = _touch(
                temp_root, 'foo.bin', 'bar.bin', 'baz.bin', 'qux.exe'
            )
            remap_paths = {
                Path('foo.bin'): 'main',
                Path('bar.bin'): 'backup',
                Path('baz.bin'): 'tertiary',
            }

            targets = update_bundle.targets_from_directory(
                temp_root, exclude=(Path('qux.exe'),), remap_paths=remap_paths
            )

            self.assertEqual(foo_path, targets['main'])
            self.assertEqual(bar_path, targets['backup'])
            self.assertEqual(baz_path, targets['tertiary'])
            self.assertNotIn('qux.exe', targets)

    def test_incomplete_remapping_logs(self):
        """Checks that incomplete remappings log warnings."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            _touch(temp_root, 'foo.bin', 'bar.bin')
            # Only foo.bin is remapped; bar.bin is left without a remap.
            remap_paths = {Path('foo.bin'): 'main'}

            with self.assertLogs(level='WARNING') as log:
                update_bundle.targets_from_directory(
                    temp_root,
                    # NOTE(review): qux.exe is never created in this test; the
                    # exclude looks like a leftover from the test above.
                    exclude=(Path('qux.exe'),),
                    remap_paths=remap_paths,
                )

            # Search every captured record rather than only the first one, so
            # an unrelated earlier warning cannot make this test fail.
            self.assertIn(
                'Some remaps defined, but not "bar.bin"',
                '\n'.join(log.output),
            )

    def test_remap_of_missing_file(self):
        """Checks that remapping a missing file raises an error."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            # bar.bin is remapped below but deliberately never created.
            _touch(temp_root, 'foo.bin')
            remap_paths = {
                Path('foo.bin'): 'main',
                Path('bar.bin'): 'backup',
            }

            with self.assertRaises(FileNotFoundError):
                update_bundle.targets_from_directory(
                    temp_root, remap_paths=remap_paths
                )


class GenUnsignedUpdateBundleTest(unittest.TestCase):
    """Test the generation of unsigned update bundles."""

    # Arbitrary payloads; they only need to be mutually distinct so that each
    # target's contents can be told apart in the assertions.
    FOO_BYTES = b'\xf0\x0b\xa4'
    BAR_BYTES = b'\x0b\xa4\x99'
    BAZ_BYTES = b'\xba\x59\x06'
    QUX_BYTES = b'\x8a\xf3\x12'

    def _write_sample_files(self, temp_root: Path):
        """Writes the four sample payloads under temp_root.

        qux.exe is placed in a 'subdir' subdirectory so that nested source
        paths are exercised as well.

        Args:
            temp_root: Existing, empty directory to populate.

        Returns:
            A (foo_path, bar_path, baz_path, qux_path) tuple.
        """
        foo_path = temp_root / 'foo.bin'
        bar_path = temp_root / 'bar.bin'
        baz_path = temp_root / 'baz.bin'
        qux_path = temp_root / 'subdir' / 'qux.exe'
        foo_path.write_bytes(self.FOO_BYTES)
        bar_path.write_bytes(self.BAR_BYTES)
        baz_path.write_bytes(self.BAZ_BYTES)
        qux_path.parent.mkdir()
        qux_path.write_bytes(self.QUX_BYTES)
        return foo_path, bar_path, baz_path, qux_path

    def test_bundle_generation(self):
        """Tests basic creation of an UpdateBundle."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            foo_path, bar_path, baz_path, qux_path = self._write_sample_files(
                Path(tempdir_name)
            )
            targets = {
                foo_path: 'foo',
                bar_path: 'bar',
                baz_path: 'baz',
                qux_path: 'qux',
            }
            serialized_root_metadata_bytes = b'\x12\x34\x56\x78'

            bundle = update_bundle.gen_unsigned_update_bundle(
                targets,
                targets_metadata_version=42,
                root_metadata=SignedRootMetadata(
                    serialized_root_metadata=serialized_root_metadata_bytes
                ),
            )

            self.assertEqual(self.FOO_BYTES, bundle.target_payloads['foo'])
            self.assertEqual(self.BAR_BYTES, bundle.target_payloads['bar'])
            self.assertEqual(self.BAZ_BYTES, bundle.target_payloads['baz'])
            self.assertEqual(self.QUX_BYTES, bundle.target_payloads['qux'])
            targets_metadata = TargetsMetadata.FromString(
                bundle.targets_metadata['targets'].serialized_targets_metadata
            )
            self.assertEqual(targets_metadata.common_metadata.version, 42)
            self.assertEqual(
                serialized_root_metadata_bytes,
                bundle.root_metadata.serialized_root_metadata,
            )

    def test_persist_to_disk(self):
        """Tests persisting the TUF repo to disk for debugging."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            foo_path, bar_path, baz_path, qux_path = self._write_sample_files(
                temp_root
            )
            targets = {
                foo_path: 'foo',
                bar_path: 'bar',
                baz_path: 'baz',
                # A target name containing '/' should persist into a subdir.
                qux_path: 'subdir/qux',
            }
            persist_path = temp_root / 'persisted'

            update_bundle.gen_unsigned_update_bundle(
                targets, persist=persist_path
            )

            self.assertEqual(
                self.FOO_BYTES, (persist_path / 'foo').read_bytes()
            )
            self.assertEqual(
                self.BAR_BYTES, (persist_path / 'bar').read_bytes()
            )
            self.assertEqual(
                self.BAZ_BYTES, (persist_path / 'baz').read_bytes()
            )
            self.assertEqual(
                self.QUX_BYTES, (persist_path / 'subdir' / 'qux').read_bytes()
            )


class ParseTargetArgTest(unittest.TestCase):
    """Test the parsing of target argument strings."""

    def test_valid_arg(self):
        """Checks that valid remap strings are parsed correctly."""
        file_path, target_name = update_bundle.parse_target_arg(
            'foo.bin > main'
        )

        self.assertEqual(Path('foo.bin'), file_path)
        self.assertEqual('main', target_name)

    def test_invalid_arg_raises(self):
        """Checks that an invalid remap string raises an error."""
        # Missing the '>' separator between the file path and target name.
        with self.assertRaises(ValueError):
            update_bundle.parse_target_arg('foo.bin main')


if __name__ == '__main__':
    unittest.main()