xref: /aosp_15_r20/external/pigweed/pw_software_update/py/update_bundle_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Unit tests for pw_software_update/update_bundle.py."""
15
16from pathlib import Path
17import tempfile
18import unittest
19
20from pw_software_update import update_bundle
21from pw_software_update.tuf_pb2 import SignedRootMetadata, TargetsMetadata
22
23
class TargetsFromDirectoryTest(unittest.TestCase):
    """Test turning a directory into TUF targets."""

    @staticmethod
    def _touch_files(root: Path, *names: str) -> tuple:
        """Creates an empty file for each of ``names`` under ``root``.

        Args:
            root: Existing directory to create the files in.
            *names: File names, relative to ``root``.

        Returns:
            The created paths, in the same order as ``names``.
        """
        paths = tuple(root / name for name in names)
        for path in paths:
            path.touch()
        return paths

    def test_excludes(self):
        """Checks that excludes are excluded."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            _, bar_path, _, qux_path = self._touch_files(
                temp_root, 'foo.bin', 'bar.bin', 'baz.bin', 'qux.exe'
            )

            targets = update_bundle.targets_from_directory(
                temp_root, exclude=(Path('foo.bin'), Path('baz.bin'))
            )

            self.assertNotIn('foo.bin', targets)
            self.assertEqual(bar_path, targets['bar.bin'])
            self.assertNotIn('baz.bin', targets)
            self.assertEqual(qux_path, targets['qux.exe'])

    def test_excludes_and_remapping(self):
        """Checks that remapping works, even in combination with excludes."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            foo_path, bar_path, baz_path, _ = self._touch_files(
                temp_root, 'foo.bin', 'bar.bin', 'baz.bin', 'qux.exe'
            )
            remap_paths = {
                Path('foo.bin'): 'main',
                Path('bar.bin'): 'backup',
                Path('baz.bin'): 'tertiary',
            }

            targets = update_bundle.targets_from_directory(
                temp_root, exclude=(Path('qux.exe'),), remap_paths=remap_paths
            )

            self.assertEqual(foo_path, targets['main'])
            self.assertEqual(bar_path, targets['backup'])
            self.assertEqual(baz_path, targets['tertiary'])
            self.assertNotIn('qux.exe', targets)

    def test_incomplete_remapping_logs(self):
        """Checks that incomplete remappings log warnings."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            self._touch_files(temp_root, 'foo.bin', 'bar.bin')
            # Only foo.bin is remapped, so bar.bin should trigger a warning.
            remap_paths = {Path('foo.bin'): 'main'}

            with self.assertLogs(level='WARNING') as log:
                update_bundle.targets_from_directory(
                    temp_root,
                    exclude=(Path('qux.exe'),),
                    remap_paths=remap_paths,
                )

            # Search every captured record rather than only the first one, so
            # an unrelated warning emitted earlier cannot break this test.
            expected = 'Some remaps defined, but not "bar.bin"'
            self.assertTrue(
                any(expected in line for line in log.output),
                f'{expected!r} not found in {log.output}',
            )

    def test_remap_of_missing_file(self):
        """Checks that remapping a missing file raises an error."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            self._touch_files(temp_root, 'foo.bin')
            # bar.bin is remapped but never created on disk.
            remap_paths = {
                Path('foo.bin'): 'main',
                Path('bar.bin'): 'backup',
            }

            with self.assertRaises(FileNotFoundError):
                update_bundle.targets_from_directory(
                    temp_root, remap_paths=remap_paths
                )
107                )
108
109
class GenUnsignedUpdateBundleTest(unittest.TestCase):
    """Test the generation of unsigned update bundles."""

    @staticmethod
    def _write_fixture_files(root: Path) -> dict:
        """Writes the four fixture payloads shared by the tests under root.

        One payload lives in a subdirectory ('subdir/qux.exe') so that nested
        source paths are exercised too.

        Args:
            root: Existing directory to write the fixtures into.

        Returns:
            A dict mapping each fixture's short name ('foo', 'bar', 'baz',
            'qux') to a ``(path, contents)`` tuple, in that order.
        """
        fixtures = {
            'foo': (root / 'foo.bin', b'\xf0\x0b\xa4'),
            'bar': (root / 'bar.bin', b'\x0b\xa4\x99'),
            'baz': (root / 'baz.bin', b'\xba\x59\x06'),
            'qux': (root / 'subdir' / 'qux.exe', b'\x8a\xf3\x12'),
        }
        for path, contents in fixtures.values():
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_bytes(contents)
        return fixtures

    def test_bundle_generation(self):
        """Tests basic creation of an UpdateBundle."""
        with tempfile.TemporaryDirectory() as tempdir_name:
            fixtures = self._write_fixture_files(Path(tempdir_name))
            # Each file becomes a target named after its fixture key.
            targets = {path: name for name, (path, _) in fixtures.items()}
            serialized_root_metadata_bytes = b'\x12\x34\x56\x78'

            bundle = update_bundle.gen_unsigned_update_bundle(
                targets,
                targets_metadata_version=42,
                root_metadata=SignedRootMetadata(
                    serialized_root_metadata=serialized_root_metadata_bytes
                ),
            )

            for name, (_, contents) in fixtures.items():
                with self.subTest(target=name):
                    self.assertEqual(contents, bundle.target_payloads[name])
            targets_metadata = TargetsMetadata.FromString(
                bundle.targets_metadata['targets'].serialized_targets_metadata
            )
            self.assertEqual(targets_metadata.common_metadata.version, 42)
            self.assertEqual(
                serialized_root_metadata_bytes,
                bundle.root_metadata.serialized_root_metadata,
            )

    def test_persist_to_disk(self):
        """Tests persisting the TUF repo to disk for debugging"""
        with tempfile.TemporaryDirectory() as tempdir_name:
            temp_root = Path(tempdir_name)
            fixtures = self._write_fixture_files(temp_root)
            # 'qux' gets a nested target name to check that such targets are
            # persisted into a matching subdirectory.
            target_names = {
                'foo': 'foo',
                'bar': 'bar',
                'baz': 'baz',
                'qux': 'subdir/qux',
            }
            targets = {
                path: target_names[key] for key, (path, _) in fixtures.items()
            }
            persist_path = temp_root / 'persisted'

            update_bundle.gen_unsigned_update_bundle(
                targets, persist=persist_path
            )

            for key, (_, contents) in fixtures.items():
                target_name = target_names[key]
                with self.subTest(target=target_name):
                    self.assertEqual(
                        contents, (persist_path / target_name).read_bytes()
                    )
193            )
194
195
class ParseTargetArgTest(unittest.TestCase):
    """Tests for update_bundle.parse_target_arg."""

    def test_valid_arg(self):
        """A 'path > name' remap string yields a Path and a target name."""
        path_part, name_part = update_bundle.parse_target_arg('foo.bin > main')

        self.assertEqual(Path('foo.bin'), path_part)
        self.assertEqual('main', name_part)

    def test_invalid_arg_raises(self):
        """A remap string without the ' > ' separator raises ValueError."""
        malformed = 'foo.bin main'
        self.assertRaises(
            ValueError, update_bundle.parse_target_arg, malformed
        )
211            update_bundle.parse_target_arg('foo.bin main')
212
213
if __name__ == '__main__':
    # Lets the file be run directly (python update_bundle_test.py) in
    # addition to being collected by a test runner.
    unittest.main()
216