#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for executorch.exir._serialize._flatbuffer."""

import os
import re
import shutil
import tempfile
import unittest
from typing import Dict, Optional, Sequence
from unittest.mock import patch

from executorch.exir._serialize import _flatbuffer
from executorch.exir._serialize._flatbuffer import (
    _program_json_to_flatbuffer,
    _ResourceFiles,
    _SchemaInfo,
)


def read_file(dir: str, filename: str) -> bytes:
    """Returns the contents of the given file.

    Args:
        dir: Directory containing the file. (NOTE: shadows the `dir`
            builtin; the name is kept so existing callers keep working.)
        filename: Name of the file, relative to `dir`.

    Returns:
        The raw bytes of the file.
    """
    with open(os.path.join(dir, filename), "rb") as fp:
        return fp.read()


# Fake resource files to use when testing _ResourceFiles.
FAKE_RESOURCES: Dict[str, bytes] = {
    "resource-1": b"resource-1 data",
    "resource-2": b"resource-2 data",
}


class TestResourceFiles(unittest.TestCase):
    """Tests loading, patching, and writing of _ResourceFiles."""

    def make_resource_files(self, files: Dict[str, bytes]) -> _ResourceFiles:
        """Returns a _ResourceFiles containing the injected fake files.

        Args:
            files: Mapping of filename to contents.
        """
        with patch.object(
            _flatbuffer.importlib.resources, "read_binary"
        ) as mock_read_binary:
            # Use the fake resource files when looking up resources. The
            # first positional argument (the package) is ignored.
            mock_read_binary.side_effect = lambda _, name: files[name]
            return _ResourceFiles(tuple(files.keys()))

    def test_load_and_write(self) -> None:
        rf: _ResourceFiles = self.make_resource_files(FAKE_RESOURCES)
        with tempfile.TemporaryDirectory() as out_dir:
            # Write the unmodified inputs to the filesystem.
            rf.write_to(out_dir)
            self.assertEqual(read_file(out_dir, "resource-1"), b"resource-1 data")
            self.assertEqual(read_file(out_dir, "resource-2"), b"resource-2 data")

    def test_load_patch_and_write(self) -> None:
        rf: _ResourceFiles = self.make_resource_files(FAKE_RESOURCES)

        # Append something to the end of each file.
        rf.patch_files(lambda data: data + b" PATCHED")

        with tempfile.TemporaryDirectory() as out_dir:
            rf.write_to(out_dir)
            self.assertEqual(
                read_file(out_dir, "resource-1"), b"resource-1 data PATCHED"
            )
            self.assertEqual(
                read_file(out_dir, "resource-2"), b"resource-2 data PATCHED"
            )


# Fake resource files to use when testing alignment-patching.
SCHEMA_FILES: Dict[str, bytes] = {
    "program.fbs": b"\n".join(
        [
            b"table Program {",
            # Space after the colon.
            b" tensor_data: [ubyte] (force_align: 8); // @executorch-tensor-alignment",
            # No spaces around the colon.
            b" delegate_data: [ubyte] (force_align:16); // @executorch-delegate-alignment",
            b" other_data: [ubyte] (force_align: 32);",
            b"}",
        ]
    ),
    "scalar_type.fbs": b"\n".join(
        [
            b"table ScalarType {",
            # Spaces around the colon.
            b" tensor_data: [ubyte] (force_align : 8); // @executorch-tensor-alignment",
            # Spaces between all tokens.
            b" delegate_data: [ubyte] ( force_align : 16 ); // @executorch-delegate-alignment",
            b" other_data: [ubyte] (force_align: 64);",
            b"}",
        ]
    ),
}


# Bad alignment values; not whole powers of 2.
BAD_ALIGNMENTS: Sequence[int] = (-1, 0, 5)


class TestPrepareSchema(unittest.TestCase):
    """Tests _prepare_schema()'s rewriting of annotated force_align values."""

    def call_prepare_schema(
        self,
        schema_files: Dict[str, bytes],
        out_dir: str,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> _SchemaInfo:
        """Calls _prepare_schema(), using `schema_files` to get the original
        contents of the schema files.

        Args:
            schema_files: Mapping of schema filename to original contents;
                served in place of the real package resources.
            out_dir: Directory passed through to _prepare_schema().
            constant_tensor_alignment: Passed through to _prepare_schema().
            delegate_alignment: Passed through to _prepare_schema().

        Returns:
            The _SchemaInfo returned by _prepare_schema().
        """
        with patch.object(
            _flatbuffer.importlib.resources, "read_binary"
        ) as mock_read_binary:
            # Use the fake resource files when looking up resources.
            mock_read_binary.side_effect = lambda _, name: schema_files[name]
            return _flatbuffer._prepare_schema(
                out_dir=out_dir,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )

    def test_unmodified(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(SCHEMA_FILES, out_dir)
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Files should not have been modified.
            for fname, contents in SCHEMA_FILES.items():
                self.assertEqual(read_file(out_dir, fname), contents)
            # Max alignment should be the largest value in the input.
            self.assertEqual(info.max_alignment, 64)

    def test_update_tensor_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES, out_dir, constant_tensor_alignment=128
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Only the tensor alignment lines should have been modified.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        # Now 128:
                        b" tensor_data: [ubyte] (force_align: 128); // @executorch-tensor-alignment",
                        b" delegate_data: [ubyte] (force_align:16); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        # Now 128, and reformatted:
                        b" tensor_data: [ubyte] (force_align: 128); // @executorch-tensor-alignment",
                        b" delegate_data: [ubyte] ( force_align : 16 ); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # Max alignment should reflect this change.
            self.assertEqual(info.max_alignment, 128)

    def test_update_delegate_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES, out_dir, delegate_alignment=256
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Only the delegate alignment lines should have been modified.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        b" tensor_data: [ubyte] (force_align: 8); // @executorch-tensor-alignment",
                        # Now 256:
                        b" delegate_data: [ubyte] (force_align: 256); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        b" tensor_data: [ubyte] (force_align : 8); // @executorch-tensor-alignment",
                        # Now 256, and reformatted:
                        b" delegate_data: [ubyte] (force_align: 256); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # Max alignment should reflect this change.
            self.assertEqual(info.max_alignment, 256)

    def test_update_tensor_and_delegate_alignment(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            info: _SchemaInfo = self.call_prepare_schema(
                SCHEMA_FILES,
                out_dir,
                constant_tensor_alignment=1,
                delegate_alignment=2,
            )
            self.assertEqual(info.root_path, os.path.join(out_dir, "program.fbs"))
            # Only the tensor and delegate alignment lines should have been
            # modified.
            self.assertEqual(
                read_file(out_dir, "program.fbs"),
                b"\n".join(
                    [
                        b"table Program {",
                        # Now 1:
                        b" tensor_data: [ubyte] (force_align: 1); // @executorch-tensor-alignment",
                        # Now 2:
                        b" delegate_data: [ubyte] (force_align: 2); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 32);",
                        b"}",
                    ]
                ),
            )
            self.assertEqual(
                read_file(out_dir, "scalar_type.fbs"),
                b"\n".join(
                    [
                        b"table ScalarType {",
                        # Now 1, and reformatted:
                        b" tensor_data: [ubyte] (force_align: 1); // @executorch-tensor-alignment",
                        # Now 2, and reformatted:
                        b" delegate_data: [ubyte] (force_align: 2); // @executorch-delegate-alignment",
                        b" other_data: [ubyte] (force_align: 64);",
                        b"}",
                    ]
                ),
            )
            # The new values are smaller than the untouched other_data
            # alignment (64), so the max is unchanged.
            self.assertEqual(info.max_alignment, 64)

    def test_bad_tensor_alignment_fails(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            for bad_alignment in BAD_ALIGNMENTS:
                # subTest will create a different top-level test entry for each
                # value, whose full names have a suffix like "(bad_alignment=5)".
                with self.subTest(bad_alignment=bad_alignment):
                    with self.assertRaises(ValueError):
                        self.call_prepare_schema(
                            SCHEMA_FILES,
                            out_dir,
                            constant_tensor_alignment=bad_alignment,
                        )

    def test_bad_delegate_alignment_fails(self) -> None:
        with tempfile.TemporaryDirectory() as out_dir:
            for bad_alignment in BAD_ALIGNMENTS:
                # subTest will create a different top-level test entry for each
                # value, whose full names have a suffix like "(bad_alignment=5)".
                with self.subTest(bad_alignment=bad_alignment):
                    with self.assertRaises(ValueError):
                        self.call_prepare_schema(
                            SCHEMA_FILES,
                            out_dir,
                            delegate_alignment=bad_alignment,
                        )


class TestProgramJsonToFlatbuffer(unittest.TestCase):
    """Tests the error reporting of _program_json_to_flatbuffer() on bad input."""

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"})
    def test_save_json_on_failure(self) -> None:
        with self.assertRaises(RuntimeError) as cm:
            _program_json_to_flatbuffer("} some bad json {")
        err_msg = cm.exception.args[0]
        self.assertIsInstance(err_msg, str)

        match = re.search(r"Moved input files to '(.*?)'", err_msg)
        self.assertIsNotNone(match, msg=f"Unexpected error message: {err_msg}")
        path = match.group(1)
        # Delete the saved files even if a later step fails; otherwise they'll
        # accumulate every time the test is run.
        self.addCleanup(shutil.rmtree, path, ignore_errors=True)

        files = frozenset(os.listdir(path))
        # Check for a couple of the files that should be there.
        self.assertIn("data.json", files)
        self.assertIn("program.fbs", files)

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"})
    def test_unable_to_save_json_on_failure(self) -> None:
        # Make the attempt to move the input files fail.
        with patch.object(
            _flatbuffer.shutil,
            "move",
            side_effect=Exception("shutil.move mock failure"),
        ):
            with self.assertRaises(RuntimeError) as cm:
                _program_json_to_flatbuffer("} some bad json {")
        err_msg = cm.exception.args[0]

        self.assertIsInstance(err_msg, str)
        self.assertIn("Failed to save input files", err_msg)

    @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: ""})
    def test_no_save_json_on_failure(self) -> None:
        with self.assertRaises(RuntimeError) as cm:
            _program_json_to_flatbuffer("} some bad json {")
        err_msg = cm.exception.args[0]

        self.assertIsInstance(err_msg, str)
        # With saving disabled, the message should only hint at how to enable it.
        self.assertIn(
            f"Set {_flatbuffer._SAVE_FLATC_ENV}=1 to save input files", err_msg
        )
        self.assertNotIn("Moved input files", err_msg)
        self.assertNotIn("Failed to save input files", err_msg)