# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
import pickle
from dataclasses import dataclass
from typing import BinaryIO, Dict, IO, List, Optional, Union
from zipfile import BadZipFile, ZipFile

from executorch import exir
from executorch.devtools.bundled_program.core import BundledProgram

from executorch.devtools.bundled_program.schema.bundled_program_schema import Value
from executorch.exir import (
    EdgeProgramManager,
    ExecutorchProgram,
    ExecutorchProgramManager,
    ExirExportedProgram,
    ExportedProgram,
)
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap

from executorch.exir.serde.export_serialize import SerializedArtifact
from executorch.exir.serde.serialize import deserialize, serialize

ProgramOutput = List[Value]

try:
    # breaking change introduced in python 3.11
    # pyre-ignore
    from enum import StrEnum
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):
        pass


class ETRecordReservedFileNames(StrEnum):
    """Names of the zip entries that have a fixed meaning inside an ETRecord archive."""

    # Empty marker entry whose presence identifies the archive as an ETRecord.
    ETRECORD_IDENTIFIER = "ETRECORD_V0"
    # Base name of the serialized edge dialect program (plus its _state_dict,
    # _constants and _example_inputs companion entries).
    EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
    ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
    # JSON dump of the ExecuTorch program's debug handle map.
    DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
    # JSON dump of the ExecuTorch program's delegate map.
    DELEGATE_MAP_NAME = "delegate_map"
    # Pickled reference outputs extracted from a BundledProgram.
    REFERENCE_OUTPUTS = "reference_outputs"


@dataclass
class ETRecord:
    # Deserialized edge dialect program, if one was stored.
    edge_dialect_program: Optional[ExportedProgram] = None
    # Additional exported programs keyed by "<module_name>/<method_name>".
    graph_map: Optional[Dict[str, ExportedProgram]] = None
    # NOTE(review): after a JSON round-trip, integer dict keys come back as str;
    # the annotation below reflects the intended, not the parsed, key type.
    _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
    _delegate_map: Optional[
        Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
    ] = None
    # Expected outputs per method name, present only when a BundledProgram was used.
    _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None


def _write_serialized_artifact(
    etrecord_zip: ZipFile, arcname: str, ep: ExportedProgram
) -> None:
    """
    Serializes `ep` and stores it in `etrecord_zip` as four entries: `arcname`
    (the program itself) and `arcname` suffixed with `_state_dict`,
    `_constants` and `_example_inputs`.
    """
    serialized_artifact = serialize(ep)
    assert isinstance(serialized_artifact.exported_program, bytes)
    etrecord_zip.writestr(arcname, serialized_artifact.exported_program)
    etrecord_zip.writestr(f"{arcname}_state_dict", serialized_artifact.state_dict)
    etrecord_zip.writestr(f"{arcname}_constants", serialized_artifact.constants)
    etrecord_zip.writestr(
        f"{arcname}_example_inputs", serialized_artifact.example_inputs
    )


def _read_serialized_artifact(etrecord_zip: ZipFile, arcname: str) -> SerializedArtifact:
    """
    Reads back the four entries written by `_write_serialized_artifact` for
    `arcname`. Raises `KeyError` if any of them is missing from the archive.
    """
    return SerializedArtifact(
        etrecord_zip.read(arcname),
        etrecord_zip.read(f"{arcname}_state_dict"),
        etrecord_zip.read(f"{arcname}_constants"),
        etrecord_zip.read(f"{arcname}_example_inputs"),
    )


def _handle_exported_program(
    etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
    """Serializes one method of an exported module under `<module_name>/<method_name>`."""
    assert isinstance(ep, ExportedProgram)
    _write_serialized_artifact(etrecord_zip, f"{module_name}/{method_name}", ep)


def _handle_export_module(
    etrecord_zip: ZipFile,
    export_module: Union[
        ExirExportedProgram,
        EdgeProgramManager,
        ExportedProgram,
    ],
    module_name: str,
) -> None:
    """
    Serializes every method of `export_module` into `etrecord_zip`.

    Single-program inputs (`ExirExportedProgram`, `ExportedProgram`) are stored
    as the `forward` method; an `EdgeProgramManager` contributes one entry per
    method it holds.

    Raises:
        RuntimeError: If `export_module` is of an unsupported type.
    """
    if isinstance(export_module, ExirExportedProgram):
        _handle_exported_program(
            etrecord_zip, module_name, "forward", export_module.exported_program
        )
    elif isinstance(export_module, ExportedProgram):
        _handle_exported_program(etrecord_zip, module_name, "forward", export_module)
    elif isinstance(
        export_module,
        (EdgeProgramManager, exir.program._program.EdgeProgramManager),
    ):
        for method in export_module.methods:
            _handle_exported_program(
                etrecord_zip,
                module_name,
                method,
                export_module.exported_program(method),
            )
    else:
        raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")


def _handle_edge_dialect_exported_program(
    etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
) -> None:
    """Serializes the edge dialect program under its reserved ETRecord name."""
    # Use the plain string value so the entry names are identical regardless
    # of whether the stdlib StrEnum or the pre-3.11 fallback is in use.
    _write_serialized_artifact(
        etrecord_zip,
        ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM.value,
        edge_dialect_exported_program,
    )


def _get_reference_outputs(
    bundled_program: BundledProgram,
) -> Dict[str, List[ProgramOutput]]:
    """
    Extracts out the expected outputs from the bundled program, keyed by the method names.

    Raises:
        ValueError: If any test case of a method has no expected outputs.
    """
    reference_outputs: Dict[str, List[ProgramOutput]] = {}
    for method_test_suite in bundled_program.method_test_suites:
        reference_outputs[method_test_suite.method_name] = []
        for test_case in method_test_suite.test_cases:
            if not test_case.expected_outputs:
                raise ValueError(
                    f"Missing at least one set of expected outputs for method {method_test_suite.method_name}."
                )
            reference_outputs[method_test_suite.method_name].append(
                test_case.expected_outputs
            )
    return reference_outputs


def generate_etrecord(
    et_record: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
    executorch_program: Union[
        ExecutorchProgram,
        ExecutorchProgramManager,
        BundledProgram,
    ],
    export_modules: Optional[
        Dict[
            str,
            Union[
                ExportedProgram,
                ExirExportedProgram,
                EdgeProgramManager,
            ],
        ]
    ] = None,
) -> None:
    """
    Generates an `ETRecord` from the given objects, serializes it and saves it to the given
    path or file object.

    The `ETRecord` is a zip archive containing:
      * every exported program in `export_modules`, one entry per method named
        `<module_name>/<method_name>`,
      * the edge dialect exported program,
      * the debug handle map and delegate map of the ExecuTorch program (as JSON)
        for Developer Tools usage, and
      * when a `BundledProgram` is passed in, the expected (reference) outputs of
        its test cases.

    Args:
        et_record: Path, or writable binary file object, where the `ETRecord` will be saved.
            A caller-supplied file object is written to but not closed.
        edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
        executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
        export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
            value being the corresponding exported module. The exported graph modules can be either the
            output of `torch.export()` or `exir.to_edge()`.

    Returns:
        None

    Raises:
        RuntimeError: If a key of `export_modules` contains a reserved ETRecord file name,
            or if `edge_dialect_program` or a value of `export_modules` has an unsupported type.
        ValueError: If a test case of a `BundledProgram` has no expected outputs.
    """
    # Validate arguments before creating the output so invalid input never
    # leaves a half-written ETRecord behind.
    if export_modules is not None:
        for module_name in export_modules:
            if any(
                reserved_name in module_name
                for reserved_name in ETRecordReservedFileNames
            ):
                raise RuntimeError(
                    f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
                )

    if isinstance(
        edge_dialect_program,
        (EdgeProgramManager, exir.program._program.EdgeProgramManager),
    ):
        edge_dialect_exported_program = edge_dialect_program.exported_program()
    elif isinstance(edge_dialect_program, ExirExportedProgram):
        edge_dialect_exported_program = edge_dialect_program.exported_program
    else:
        raise RuntimeError(
            f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
        )

    if isinstance(et_record, (str, os.PathLike)):
        et_record = os.fspath(et_record)  # pyre-ignore

    # The context manager guarantees ZipFile.close() runs, which is what writes
    # the archive's central directory; without it the output is only valid if
    # a finalizer happens to close the file.
    with ZipFile(et_record, "w") as etrecord_zip:
        # Write the magic file identifier that will be used to verify that this file
        # is an etrecord when it's used later in the Developer Tools.
        etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

        if export_modules is not None:
            for module_name, export_module in export_modules.items():
                _handle_export_module(etrecord_zip, export_module, module_name)

        _handle_edge_dialect_exported_program(
            etrecord_zip, edge_dialect_exported_program
        )

        # When a BundledProgram is passed in, extract the reference outputs and save in a file
        if isinstance(executorch_program, BundledProgram):
            reference_outputs = _get_reference_outputs(executorch_program)
            etrecord_zip.writestr(
                ETRecordReservedFileNames.REFERENCE_OUTPUTS,
                # @lint-ignore PYTHONPICKLEISBAD
                pickle.dumps(reference_outputs),
            )
            executorch_program = executorch_program.executorch_program

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
            json.dumps(executorch_program.debug_handle_map),
        )

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DELEGATE_MAP_NAME,
            json.dumps(executorch_program.delegate_map),
        )


def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
    """
    Parses an `ETRecord` file and returns an `ETRecord` object that contains the deserialized
    edge dialect program, the graph map of additional exported programs, the debug handle map,
    the delegate map and, when present, the reference outputs.
    In the graph map in the returned `ETRecord` object if a model with multiple entry points was provided
    originally by the user during `ETRecord` generation then each entry point will be stored as a separate
    graph module in the `ETRecord` object with the name being `the original module name + "/" + the
    name of the entry point`.

    Reference outputs are restored with `pickle`, which can execute arbitrary code:
    only parse `ETRecord` files from trusted sources.

    Args:
        etrecord_path: Path to the `ETRecord` file.

    Returns:
        `ETRecord` object.

    Raises:
        RuntimeError: If the file is not a zip archive or lacks the ETRecord identifier.
    """

    try:
        etrecord_zip = ZipFile(etrecord_path, "r")
    except BadZipFile as e:
        raise RuntimeError("Invalid etrecord file passed in.") from e

    graph_map: Dict[str, ExportedProgram] = {}
    debug_handle_map = None
    delegate_map = None
    edge_dialect_program = None
    reference_outputs = None

    with etrecord_zip:
        file_list = etrecord_zip.namelist()

        if ETRecordReservedFileNames.ETRECORD_IDENTIFIER not in file_list:
            raise RuntimeError(
                "ETRecord identifier missing from etrecord file passed in. Either an invalid file was passed in or the file is corrupt."
            )

        # Program entries are kept in archive order so graph_map is deterministic.
        serialized_exported_program_files: List[str] = []
        serialized_state_dict_files = set()
        serialized_constants_files = set()
        serialized_example_inputs_files = set()
        for entry in file_list:
            if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME:
                # NOTE: JSON object keys are always strings, so integer keys
                # written by generate_etrecord are returned here as str.
                debug_handle_map = json.loads(
                    etrecord_zip.read(ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME)
                )
            elif entry == ETRecordReservedFileNames.DELEGATE_MAP_NAME:
                delegate_map = json.loads(
                    etrecord_zip.read(ETRecordReservedFileNames.DELEGATE_MAP_NAME)
                )
            elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
                continue
            elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
                edge_dialect_program = deserialize(
                    _read_serialized_artifact(etrecord_zip, entry)
                )
            elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
                # SECURITY: unpickling runs arbitrary code from the file; the
                # ETRecord must come from a trusted source.
                # @lint-ignore PYTHONPICKLEISBAD
                reference_outputs = pickle.loads(
                    etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS)
                )
            elif entry.endswith("state_dict"):
                serialized_state_dict_files.add(entry)
            elif entry.endswith("constants"):
                serialized_constants_files.add(entry)
            elif entry.endswith("example_inputs"):
                serialized_example_inputs_files.add(entry)
            else:
                serialized_exported_program_files.append(entry)

        for serialized_file in serialized_exported_program_files:
            assert (
                f"{serialized_file}_state_dict" in serialized_state_dict_files
            ), f"Could not find corresponding state dict file for {serialized_file}."
            graph_map[serialized_file] = deserialize(
                _read_serialized_artifact(etrecord_zip, serialized_file)
            )

    return ETRecord(
        edge_dialect_program=edge_dialect_program,
        graph_map=graph_map,
        _debug_handle_map=debug_handle_map,
        _delegate_map=delegate_map,
        _reference_outputs=reference_outputs,
    )