xref: /aosp_15_r20/external/executorch/exir/_serialize/test/test_flatbuffer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import os
9import re
10import shutil
11import tempfile
12import unittest
13from typing import Dict, Optional, Sequence
14from unittest.mock import patch
15
16from executorch.exir._serialize import _flatbuffer
17from executorch.exir._serialize._flatbuffer import (
18    _program_json_to_flatbuffer,
19    _ResourceFiles,
20    _SchemaInfo,
21)
22
23
def read_file(dir: str, filename: str) -> bytes:
    """Reads `filename` inside directory `dir` and returns its raw bytes.

    Args:
        dir: Directory containing the file.
        filename: Name of the file, relative to `dir`.

    Returns:
        The complete, undecoded contents of the file.
    """
    full_path = os.path.join(dir, filename)
    with open(full_path, "rb") as handle:
        contents = handle.read()
    return contents
28
29
# Fake resource files to use when testing _ResourceFiles: maps each filename
# to "<filename> data" as bytes, i.e.
#   {"resource-1": b"resource-1 data", "resource-2": b"resource-2 data"}
FAKE_RESOURCES: Dict[str, bytes] = {
    name: f"{name} data".encode() for name in ("resource-1", "resource-2")
}
35
36
class TestResourceFiles(unittest.TestCase):
    def make_resource_files(self, files: Dict[str, bytes]) -> _ResourceFiles:
        """Builds a _ResourceFiles whose contents come from `files` rather
        than from the real package resources.

        Args:
            files: Mapping of filename to contents.
        """

        def fake_read_binary(_package: object, name: str) -> bytes:
            # Serve every read_binary() lookup from the in-memory mapping.
            return files[name]

        with patch.object(
            _flatbuffer.importlib.resources,
            "read_binary",
            side_effect=fake_read_binary,
        ):
            return _ResourceFiles(tuple(files))

    def test_load_and_write(self) -> None:
        resource_files: _ResourceFiles = self.make_resource_files(FAKE_RESOURCES)
        with tempfile.TemporaryDirectory() as out_dir:
            # Write the contents to disk exactly as they were loaded.
            resource_files.write_to(out_dir)
            for filename, expected in (
                ("resource-1", b"resource-1 data"),
                ("resource-2", b"resource-2 data"),
            ):
                self.assertEqual(read_file(out_dir, filename), expected)

    def test_load_patch_and_write(self) -> None:
        resource_files: _ResourceFiles = self.make_resource_files(FAKE_RESOURCES)

        # Append a marker to the end of every file's contents.
        resource_files.patch_files(lambda data: data + b" PATCHED")

        with tempfile.TemporaryDirectory() as out_dir:
            resource_files.write_to(out_dir)
            for filename, expected in (
                ("resource-1", b"resource-1 data PATCHED"),
                ("resource-2", b"resource-2 data PATCHED"),
            ):
                self.assertEqual(read_file(out_dir, filename), expected)
73
74
def _schema_source(*lines: bytes) -> bytes:
    """Returns `lines` joined with newlines, with no trailing newline."""
    return b"\n".join(lines)


# Fake schema resource files to use when testing alignment-patching. The lines
# tagged `@executorch-tensor-alignment` / `@executorch-delegate-alignment`
# spell their `force_align` attribute with different whitespace; the untagged
# `other_data` lines hold the largest alignment in each file.
SCHEMA_FILES: Dict[str, bytes] = {
    "program.fbs": _schema_source(
        b"table Program {",
        # Space after the colon.
        b"  tensor_data: [ubyte] (force_align: 8); // @executorch-tensor-alignment",
        # No spaces around the colon.
        b"  delegate_data: [ubyte] (force_align:16); // @executorch-delegate-alignment",
        b"  other_data: [ubyte] (force_align: 32);",
        b"}",
    ),
    "scalar_type.fbs": _schema_source(
        b"table ScalarType {",
        # Spaces around the colon.
        b"  tensor_data: [ubyte] (force_align : 8); // @executorch-tensor-alignment",
        # Spaces between all tokens.
        b"  delegate_data: [ubyte] ( force_align : 16 ); // @executorch-delegate-alignment",
        b"  other_data: [ubyte] (force_align: 64);",
        b"}",
    ),
}
100
101
# Alignment values that are not positive whole powers of 2, so the alignment
# overrides must reject them.
BAD_ALIGNMENTS: Sequence[int] = (
    -1,  # Negative.
    0,  # Zero.
    5,  # Positive, but not a power of 2.
)
104
105
class TestPrepareSchema(unittest.TestCase):
    def call_prepare_schema(
        self,
        schema_files: Dict[str, bytes],
        out_dir: str,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> _SchemaInfo:
        """Calls _prepare_schema(), using `schema_files` to get the original
        contents of the schema files.

        Args:
            schema_files: Mapping of schema filename to original contents.
                Lookups through importlib.resources.read_binary() are served
                from this mapping instead of from the real package resources.
            out_dir: Passed to _prepare_schema() as `out_dir`.
            constant_tensor_alignment: Forwarded to _prepare_schema() as-is.
            delegate_alignment: Forwarded to _prepare_schema() as-is.

        Returns:
            The _SchemaInfo returned by _prepare_schema().
        """
        with patch.object(
            _flatbuffer.importlib.resources, "read_binary"
        ) as mock_read_binary:
            # Use the fake resource files when looking up resources.
            mock_read_binary.side_effect = lambda _, name: schema_files[name]
            return _flatbuffer._prepare_schema(
                out_dir=out_dir,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )

    def test_unmodified(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(SCHEMA_FILES, out_dir)
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Files should not have been modified. subTest reports which file
            # differs if one does.
            for fname, original in SCHEMA_FILES.items():
                with self.subTest(fname=fname):
                    self.assertEqual(read_file(out_dir, fname), original)
            # Max alignment should be the largest value in the input (the
            # other_data line of scalar_type.fbs).
            self.assertEqual(info.max_alignment, 64)

    def test_update_tensor_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES, out_dir, constant_tensor_alignment=128
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Only the tensor alignment lines should have been modified.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        # Now 128:
                        b"  tensor_data: [ubyte] (force_align: 128); // @executorch-tensor-alignment",
                        b"  delegate_data: [ubyte] (force_align:16); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        # Now 128, and reformatted:
                        b"  tensor_data: [ubyte] (force_align: 128); // @executorch-tensor-alignment",
                        b"  delegate_data: [ubyte] ( force_align : 16 ); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # Max alignment should reflect this change.
            self.assertEqual(info.max_alignment, 128)

    def test_update_delegate_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES, out_dir, delegate_alignment=256
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Only the delegate alignment lines should have been modified.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        b"  tensor_data: [ubyte] (force_align: 8); // @executorch-tensor-alignment",
                        # Now 256, and reformatted:
                        b"  delegate_data: [ubyte] (force_align: 256); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        b"  tensor_data: [ubyte] (force_align : 8); // @executorch-tensor-alignment",
                        # Now 256, and reformatted:
                        b"  delegate_data: [ubyte] (force_align: 256); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # Max alignment should reflect this change.
            self.assertEqual(info.max_alignment, 256)

    def test_update_tensor_and_delegate_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES,
                out_dir,
                constant_tensor_alignment=1,
                delegate_alignment=2,
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Both the tensor and the delegate alignment lines should have been
            # modified; the untagged other_data lines must stay untouched.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        # Now 1:
                        b"  tensor_data: [ubyte] (force_align: 1); // @executorch-tensor-alignment",
                        # Now 2, and reformatted:
                        b"  delegate_data: [ubyte] (force_align: 2); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        # Now 1, and reformatted:
                        b"  tensor_data: [ubyte] (force_align: 1); // @executorch-tensor-alignment",
                        # Now 2, and reformatted:
                        b"  delegate_data: [ubyte] (force_align: 2); // @executorch-delegate-alignment",
                        b"  other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # The new values are smaller than the untouched other_data
            # alignment in scalar_type.fbs, so the max stays at 64.
            self.assertEqual(info.max_alignment, 64)

    def test_bad_tensor_alignment_fails(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            for bad_alignment in BAD_ALIGNMENTS:
                # subTest reports each value separately; failures carry a
                # suffix like "(bad_alignment=5)".
                with self.subTest(bad_alignment=bad_alignment):
                    with self.assertRaises(ValueError):
                        self.call_prepare_schema(
                            SCHEMA_FILES,
                            out_dir,
                            constant_tensor_alignment=bad_alignment,
                        )

    def test_bad_delegate_alignment_fails(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            for bad_alignment in BAD_ALIGNMENTS:
                # subTest reports each value separately; failures carry a
                # suffix like "(bad_alignment=5)".
                with self.subTest(bad_alignment=bad_alignment):
                    with self.assertRaises(ValueError):
                        self.call_prepare_schema(
                            SCHEMA_FILES,
                            out_dir,
                            delegate_alignment=bad_alignment,
                        )
275
276
class TestProgramJsonToFlatbuffer(unittest.TestCase):
    """Tests the RuntimeError that _program_json_to_flatbuffer() raises for
    malformed JSON, depending on the _SAVE_FLATC_ENV environment variable."""

    def _bad_json_error_message(self) -> str:
        """Feeds malformed JSON to _program_json_to_flatbuffer() and returns
        the first argument of the RuntimeError it must raise."""
        with self.assertRaises(
            RuntimeError, msg="Should have raised an exception"
        ) as ctx:
            _program_json_to_flatbuffer("} some bad json {")
        err_msg = ctx.exception.args[0]
        self.assertIsInstance(err_msg, str)
        return err_msg

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"})
    def test_save_json_on_failure(self) -> None:
        err_msg = self._bad_json_error_message()

        match = re.search(r"Moved input files to '(.*?)'", err_msg)
        if match is None:
            # self.fail() never returns, which also narrows `match` below.
            self.fail(f"Unexpected error message: {err_msg}")
        path = match.group(1)
        # Delete the files otherwise they'll accumulate every time the test is
        # run. Registering the cleanup now means it still runs if listing the
        # directory below fails.
        self.addCleanup(shutil.rmtree, path, ignore_errors=True)

        files = frozenset(os.listdir(path))
        # Check for a couple of the files that should be there.
        self.assertIn("data.json", files)
        self.assertIn("program.fbs", files)

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"})
    def test_unable_to_save_json_on_failure(self) -> None:
        # Make moving the input files fail so the save step cannot succeed.
        with patch.object(
            _flatbuffer.shutil,
            "move",
            side_effect=Exception("shutil.move mock failure"),
        ):
            err_msg = self._bad_json_error_message()

        self.assertIn("Failed to save input files", err_msg)

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: ""})
    def test_no_save_json_on_failure(self) -> None:
        err_msg = self._bad_json_error_message()

        # With the variable unset, the message should only suggest setting it.
        self.assertIn(
            f"Set {_flatbuffer._SAVE_FLATC_ENV}=1 to save input files", err_msg
        )
        self.assertNotIn("Moved input files", err_msg)
        self.assertNotIn("Failed to save input files", err_msg)
332