# xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/export_example.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# pyre-ignore-all-errors
import argparse
import copy
import os

import torch
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.devtools import generate_etrecord
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.exir.backend.backend_api import to_backend, validation_disabled
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import save_pte_program

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
22
23
def main() -> None:
    """Export a sample model as a QNN-delegated ExecuTorch program (.pte).

    Quantizes the chosen model with the QNN PT2E quantizer, lowers it to the
    QNN backend targeting the SM8550 SoC, optionally emits an ETRecord for
    profiling, and saves the resulting program into the output folder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )
    # FIX: this flag was previously declared with both action="store_true"
    # and required=True, which forced it on every invocation and made the
    # later `if args.generate_etrecord:` check dead code. A store_true flag
    # must be optional (it defaults to False when omitted).
    parser.add_argument(
        "-g",
        "--generate_etrecord",
        action="store_true",
        help="Generate ETRecord metadata to link with runtime results (used for profiling)",
    )

    parser.add_argument(
        "-f",
        "--output_folder",
        type=str,
        default="",
        help="The folder to store the exported program",
    )

    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    # Get quantizer
    quantizer = QnnQuantizer()

    # Typical pytorch 2.0 quantization flow:
    # export -> prepare (insert observers) -> calibrate -> convert.
    m = torch.export.export(model.eval(), example_inputs).module()
    m = prepare_pt2e(m, quantizer)
    # Calibration: a single forward pass over the example inputs.
    m(*example_inputs)
    # Get the quantized model
    m = convert_pt2e(m)

    # Capture program for edge IR
    edge_program = capture_program(m, example_inputs)

    # this is needed for the ETRecord as lowering modifies the graph in-place
    edge_copy = copy.deepcopy(edge_program)

    # Delegate to QNN backend (model is quantized, so fp16 stays off).
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
    )
    qnn_partitioner = QnnPartitioner(
        generate_qnn_executorch_compiler_spec(
            soc_model=QcomChipset.SM8550,
            backend_options=backend_options,
        )
    )
    with validation_disabled():
        delegated_program = edge_program
        delegated_program.exported_program = to_backend(
            edge_program.exported_program, qnn_partitioner
        )

    executorch_program = delegated_program.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    if args.generate_etrecord:
        # FIX: plain string concatenation produced e.g. "outetrecord.bin"
        # when the folder argument lacked a trailing separator. os.path.join
        # is identical for the default "" and correct otherwise.
        etrecord_path = os.path.join(args.output_folder, "etrecord.bin")
        generate_etrecord(etrecord_path, edge_copy, executorch_program)

    save_pte_program(executorch_program, args.model_name, args.output_folder)
102
103
# Script entry point: dispatches to main(); the guard itself carries no logic,
# hence the coverage pragma.
if __name__ == "__main__":
    main()  # pragma: no cover
106