# xref: /aosp_15_r20/external/executorch/examples/xnnpack/aot_compiler.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9# Example script for exporting simple models to flatbuffer
10
11import argparse
12import copy
13import logging
14
15import torch
16from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17from executorch.devtools import generate_etrecord
18from executorch.exir import (
19    EdgeCompileConfig,
20    ExecutorchBackendConfig,
21    to_edge_transform_and_lower,
22)
23from executorch.extension.export_util.utils import save_pte_program
24
25from ..models import MODEL_NAME_TO_MODEL
26from ..models.model_factory import EagerModelFactory
27from . import MODEL_NAME_TO_OPTIONS
28from .quantization.utils import quantize
29
30
# Tag every log record with its level, timestamp, and source location.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
33
34
if __name__ == "__main__":
    # CLI: export a named example model, optionally quantize it, lower it to the
    # XNNPACK delegate, and save the resulting .pte program (plus an optional
    # ETRecord for devtools).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=True,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()

    # --delegate defaults to True; quantization without XNNPACK delegation is
    # not implemented, so reject an explicit opt-out early.
    if not args.delegate:
        raise NotImplementedError(
            "T161880157: Quantization-only without delegation is not supported yet"
        )

    # Only models registered in MODEL_NAME_TO_OPTIONS have a quantization recipe.
    if args.model_name not in MODEL_NAME_TO_OPTIONS and args.quantize:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    model = model.eval()
    # pre-autograd export. eventually this will become torch.export
    ep = torch.export.export_for_training(model, example_inputs)
    model = ep.module()

    if args.quantize:
        logging.info("Quantizing Model...")
        # TODO(T165162973): This pass shall eventually be folded into quantizer
        model = quantize(model, example_inputs)
        # Fix: re-export the quantized module so the quantized graph — not the
        # fp32 program captured above — is what gets lowered below. Without
        # this, the quantize() result was silently discarded.
        ep = torch.export.export(model, example_inputs)

    # Lower to the XNNPACK delegate. IR validity checks are skipped for
    # quantized graphs (quantized ops fail the core-ATen check).
    edge = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=not args.quantize,
            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
        ),
    )
    logging.info(f"Exported and lowered graph:\n{edge.exported_program().graph}")

    # this is needed for the ETRecord as lowering modifies the graph in-place
    edge_copy = copy.deepcopy(edge)

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    if args.etrecord is not None:
        generate_etrecord(args.etrecord, edge_copy, exec_prog)
        logging.info(f"Saved ETRecord to {args.etrecord}")

    # Name the output after the model and its precision, e.g. mv2_xnnpack_q8.pte
    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)
120