xref: /aosp_15_r20/external/executorch/devtools/etrecord/_etrecord.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import json
10import os
11import pickle
12from dataclasses import dataclass
13from typing import BinaryIO, Dict, IO, List, Optional, Union
14from zipfile import BadZipFile, ZipFile
15
16from executorch import exir
17from executorch.devtools.bundled_program.core import BundledProgram
18
19from executorch.devtools.bundled_program.schema.bundled_program_schema import Value
20from executorch.exir import (
21    EdgeProgramManager,
22    ExecutorchProgram,
23    ExecutorchProgramManager,
24    ExirExportedProgram,
25    ExportedProgram,
26)
27from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
28
29from executorch.exir.serde.export_serialize import SerializedArtifact
30from executorch.exir.serde.serialize import deserialize, serialize
31
# Type alias for one set of program outputs: a list of bundled-program
# `Value`s (e.g. the expected outputs of a single bundled test case, as
# collected by `_get_reference_outputs`).
ProgramOutput = List[Value]

try:
    # `enum.StrEnum` only exists from Python 3.11 onwards (breaking change
    # introduced in python 3.11); older versions use the fallback below.
    # pyre-ignore
    from enum import StrEnum
except ImportError:
    from enum import Enum

    # Minimal stand-in for Python < 3.11: mixing in `str` makes members
    # compare equal to (and hash like) their string values.
    # NOTE(review): unlike the 3.11+ `StrEnum`, `str(member)` on this
    # fallback returns "ClassName.MEMBER" rather than the value.
    class StrEnum(str, Enum):
        pass
43
44
class ETRecordReservedFileNames(StrEnum):
    """
    Names of the fixed entries inside an `ETRecord` zip archive.

    `generate_etrecord` rejects any `export_modules` key that contains one of
    these names as a substring, so user-provided modules cannot collide with
    them.
    """

    # Empty marker entry; `parse_etrecord` rejects archives that lack it.
    ETRECORD_IDENTIFIER = "ETRECORD_V0"
    # Serialized edge dialect program. Its companion entries use the same name
    # with `_state_dict`, `_constants` and `_example_inputs` suffixes.
    EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
    # Not written or read anywhere in this file; still reserved so that no
    # export module can take the name.
    ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
    # JSON-encoded `debug_handle_map` of the ExecuTorch program.
    DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
    # JSON-encoded `delegate_map` of the ExecuTorch program.
    DELEGATE_MAP_NAME = "delegate_map"
    # Pickled expected outputs keyed by method name; only present when a
    # `BundledProgram` was passed to `generate_etrecord`.
    REFERENCE_OUTPUTS = "reference_outputs"
52
53
@dataclass
class ETRecord:
    """
    In-memory form of an `ETRecord` file, as returned by `parse_etrecord`.
    """

    # Edge dialect program, if the archive contained one.
    edge_dialect_program: Optional[ExportedProgram] = None
    # Exported programs passed as `export_modules` at generation time, keyed
    # by "<module name>/<method name>".
    graph_map: Optional[Dict[str, ExportedProgram]] = None
    # Decoded from JSON. NOTE(review): JSON object keys are always strings, so
    # after a round-trip through the file these keys are `str`, not `int` as
    # annotated — confirm consumers account for that.
    _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
    # Delegate debug info, decoded from JSON (same str-key caveat as above for
    # the inner "int" keys).
    _delegate_map: Optional[
        Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
    ] = None
    # Expected outputs per method name; set when the record was generated from
    # a `BundledProgram`.
    _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None
63
64
def _handle_exported_program(
    etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
    """
    Serializes `ep` and stores it in `etrecord_zip`.

    The serialized program is written to the entry "<module_name>/<method_name>";
    its state dict, constants and example inputs go to entries with the same
    name plus a `_state_dict`, `_constants` or `_example_inputs` suffix.
    """
    assert isinstance(ep, ExportedProgram)
    artifact = serialize(ep)
    assert isinstance(artifact.exported_program, bytes)

    entry_name = f"{module_name}/{method_name}"
    etrecord_zip.writestr(entry_name, artifact.exported_program)
    for suffix, payload in (
        ("state_dict", artifact.state_dict),
        ("constants", artifact.constants),
        ("example_inputs", artifact.example_inputs),
    ):
        etrecord_zip.writestr(f"{entry_name}_{suffix}", payload)
84
85
def _handle_export_module(
    etrecord_zip: ZipFile,
    export_module: Union[
        ExirExportedProgram,
        EdgeProgramManager,
        ExportedProgram,
    ],
    module_name: str,
) -> None:
    """
    Writes every exported program held by `export_module` into `etrecord_zip`
    under `module_name`.

    Single-program inputs (`ExirExportedProgram`, `ExportedProgram`) are stored
    as the "forward" method; an `EdgeProgramManager` contributes one entry per
    method it exposes.

    Raises:
        RuntimeError: if `export_module` is none of the supported types.
    """
    if isinstance(export_module, ExirExportedProgram):
        method_programs = (("forward", export_module.exported_program),)
    elif isinstance(export_module, ExportedProgram):
        method_programs = (("forward", export_module),)
    elif isinstance(
        export_module,
        (EdgeProgramManager, exir.program._program.EdgeProgramManager),
    ):
        # Lazy generator: each method's program is fetched right before it is
        # written, exactly as in a plain nested loop.
        method_programs = (
            (method, export_module.exported_program(method))
            for method in export_module.methods
        )
    else:
        raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")

    for method_name, program in method_programs:
        _handle_exported_program(etrecord_zip, module_name, method_name, program)
114
115
def _handle_edge_dialect_exported_program(
    etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
) -> None:
    """
    Serializes the edge dialect program into `etrecord_zip`.

    The program goes under the reserved `EDGE_DIALECT_EXPORTED_PROGRAM` entry
    name; its state dict, constants and example inputs go to companion entries
    with `_state_dict`, `_constants` and `_example_inputs` suffixes.
    """
    base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
    artifact = serialize(edge_dialect_exported_program)
    assert isinstance(artifact.exported_program, bytes)

    etrecord_zip.writestr(base_name, artifact.exported_program)
    companions = {
        "state_dict": artifact.state_dict,
        "constants": artifact.constants,
        "example_inputs": artifact.example_inputs,
    }
    for suffix, payload in companions.items():
        etrecord_zip.writestr(f"{base_name}_{suffix}", payload)
138
139
def _get_reference_outputs(
    bundled_program: BundledProgram,
) -> Dict[str, List[ProgramOutput]]:
    """
    Collects the expected outputs of every test case in `bundled_program`.

    Returns:
        A dict mapping each method name to the expected outputs of that
        method's test cases, in test-case order.

    Raises:
        ValueError: if any test case has no expected outputs.
    """
    outputs_by_method: Dict[str, List[ProgramOutput]] = {}
    for suite in bundled_program.method_test_suites:
        method_name = suite.method_name
        method_outputs: List[ProgramOutput] = []
        # A later suite with the same method name replaces the earlier list.
        outputs_by_method[method_name] = method_outputs
        for case in suite.test_cases:
            expected = case.expected_outputs
            if not expected:
                raise ValueError(
                    f"Missing at least one set of expected outputs for method {method_name}."
                )
            method_outputs.append(expected)
    return outputs_by_method
158
159
def _check_export_module_names(
    export_modules: Optional[
        Dict[str, Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager]]
    ],
) -> None:
    """Raises RuntimeError if an `export_modules` key contains a reserved ETRecord name."""
    if export_modules is None:
        return
    for module_name in export_modules:
        contains_reserved_name = any(
            reserved_name in module_name
            for reserved_name in ETRecordReservedFileNames
        )
        if contains_reserved_name:
            raise RuntimeError(
                f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
            )


def _get_edge_dialect_exported_program(
    edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
) -> ExportedProgram:
    """Returns the `ExportedProgram` held by an edge dialect program object, or raises RuntimeError."""
    if isinstance(
        edge_dialect_program,
        (EdgeProgramManager, exir.program._program.EdgeProgramManager),
    ):
        return edge_dialect_program.exported_program()
    if isinstance(edge_dialect_program, ExirExportedProgram):
        return edge_dialect_program.exported_program
    raise RuntimeError(
        f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
    )


def generate_etrecord(
    et_record: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
    executorch_program: Union[
        ExecutorchProgram,
        ExecutorchProgramManager,
        BundledProgram,
    ],
    export_modules: Optional[
        Dict[
            str,
            Union[
                ExportedProgram,
                ExirExportedProgram,
                EdgeProgramManager,
            ],
        ]
    ] = None,
) -> None:
    """
    Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.

    The `ETRecord` is a zip archive containing:
      - an identifier entry marking the file as an `ETRecord`,
      - every exported program in the `export_modules` dict (if provided),
      - the exported program held by `edge_dialect_program`,
      - the debug handle map and delegate map of `executorch_program`, for
        Developer Tools usage, and
      - when `executorch_program` is a `BundledProgram`, the expected outputs of
        its test cases, keyed by method name.

    All arguments are validated before the output is created, so invalid input
    does not leave a truncated or partially written file behind.

    Args:
        et_record: Path (or writable binary file object) where the `ETRecord` will be saved.
            A file object passed in is not closed by this function.
        edge_dialect_program: `EdgeProgramManager` for this model returned by the call to `to_edge()`
            (a legacy `ExirExportedProgram` is also accepted).
        executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()`,
            or the `BundledProgram` of this model.
        export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
            value being the corresponding exported module. The exported graph modules can be either the
            output of `torch.export()` or `exir.to_edge()`.

    Returns:
        None

    Raises:
        RuntimeError: if an `export_modules` key contains a reserved `ETRecord` name,
            or if `edge_dialect_program` or an export module has an unsupported type.
        ValueError: if a `BundledProgram` test case has no expected outputs.
    """
    _check_export_module_names(export_modules)
    edge_dialect_exported_program = _get_edge_dialect_exported_program(
        edge_dialect_program
    )

    reference_outputs = None
    if isinstance(executorch_program, BundledProgram):
        # Keep the expected outputs, then continue with the wrapped program,
        # whose debug handle / delegate maps are written below.
        reference_outputs = _get_reference_outputs(executorch_program)
        executorch_program = executorch_program.executorch_program

    if isinstance(et_record, (str, os.PathLike)):
        et_record = os.fspath(et_record)  # pyre-ignore

    # `with` guarantees close() runs, which writes the zip central directory,
    # instead of relying on garbage collection. ZipFile only closes file
    # objects it opened itself, so a caller-provided file object stays open.
    with ZipFile(et_record, "w") as etrecord_zip:
        # Write the magic file identifier that will be used to verify that this
        # file is an etrecord when it's used later in the Developer Tools.
        etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

        if export_modules is not None:
            for module_name, export_module in export_modules.items():
                _handle_export_module(etrecord_zip, export_module, module_name)

        _handle_edge_dialect_exported_program(
            etrecord_zip, edge_dialect_exported_program
        )

        if reference_outputs is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.REFERENCE_OUTPUTS,
                # @lint-ignore PYTHONPICKLEISBAD
                pickle.dumps(reference_outputs),
            )

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
            json.dumps(executorch_program.debug_handle_map),
        )

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DELEGATE_MAP_NAME,
            json.dumps(executorch_program.delegate_map),
        )
258
259
def _read_serialized_artifact(
    etrecord_zip: ZipFile, base_name: str
) -> SerializedArtifact:
    """Reads the program entry `base_name` plus its `_state_dict`, `_constants` and `_example_inputs` companions."""
    return SerializedArtifact(
        etrecord_zip.read(base_name),
        etrecord_zip.read(f"{base_name}_state_dict"),
        etrecord_zip.read(f"{base_name}_constants"),
        etrecord_zip.read(f"{base_name}_example_inputs"),
    )


def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
    """
    Parses an `ETRecord` file and returns an `ETRecord` object.

    The returned object contains the deserialized edge dialect program, the
    exported programs that were passed as `export_modules` at generation time,
    the debug handle map, the delegate map and, if present, the reference outputs.
    In the graph map, if a model with multiple entry points was provided
    originally by the user during `ETRecord` generation, each entry point is
    stored as a separate exported program named
    `<original module name>/<entry point name>`. The graph map follows the
    order of the entries in the archive.

    NOTE: the reference outputs are stored with `pickle`, which can execute
    arbitrary code while loading; only parse `ETRecord` files from trusted sources.

    Args:
        etrecord_path: Path to the `ETRecord` file.

    Returns:
        `ETRecord` object.

    Raises:
        RuntimeError: if the file is not a zip archive or lacks the `ETRecord`
            identifier entry.
    """
    try:
        etrecord_zip = ZipFile(etrecord_path, "r")
    except BadZipFile as err:
        raise RuntimeError("Invalid etrecord file passed in.") from err

    # Close the archive (and the file handle ZipFile opened for it) as soon as
    # parsing finishes, including on error paths, instead of leaking it.
    with etrecord_zip:
        file_list = etrecord_zip.namelist()

        if ETRecordReservedFileNames.ETRECORD_IDENTIFIER not in file_list:
            raise RuntimeError(
                "ETRecord identifier missing from etrecord file passed in. Either an invalid file was passed in or the file is corrupt."
            )

        graph_map: Dict[str, ExportedProgram] = {}
        debug_handle_map = None
        delegate_map = None
        edge_dialect_program = None
        reference_outputs = None

        # Main entries of the programs written from `export_modules`. A list
        # (not a set) keeps `graph_map` in deterministic archive order.
        serialized_exported_program_files: List[str] = []
        serialized_state_dict_files = set()
        for entry in file_list:
            if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME:
                debug_handle_map = json.loads(etrecord_zip.read(entry))
            elif entry == ETRecordReservedFileNames.DELEGATE_MAP_NAME:
                delegate_map = json.loads(etrecord_zip.read(entry))
            elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
                continue
            elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
                edge_dialect_program = deserialize(
                    _read_serialized_artifact(etrecord_zip, entry)
                )
            elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
                # Unpickling: see the trust note in the docstring.
                # @lint-ignore PYTHONPICKLEISBAD
                reference_outputs = pickle.loads(etrecord_zip.read(entry))
            elif entry.endswith("state_dict"):
                serialized_state_dict_files.add(entry)
            elif entry.endswith(("constants", "example_inputs")):
                # Companion entries are read via _read_serialized_artifact.
                continue
            else:
                serialized_exported_program_files.append(entry)

        for serialized_file in serialized_exported_program_files:
            assert (
                f"{serialized_file}_state_dict" in serialized_state_dict_files
            ), f"Could not find corresponding state dict file for {serialized_file}."
            graph_map[serialized_file] = deserialize(
                _read_serialized_artifact(etrecord_zip, serialized_file)
            )

    return ETRecord(
        edge_dialect_program=edge_dialect_program,
        graph_map=graph_map,
        _debug_handle_map=debug_handle_map,
        _delegate_map=delegate_map,
        _reference_outputs=reference_outputs,
    )
356