xref: /aosp_15_r20/external/executorch/examples/portable/scripts/export.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# Example script for exporting simple models to flatbuffer
8
9import argparse
10import logging
11
12import torch
13
14from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
15from executorch.extension.export_util.utils import (
16    export_to_edge,
17    export_to_exec_prog,
18    save_pte_program,
19)
20
21from ...models import MODEL_NAME_TO_MODEL
22from ...models.model_factory import EagerModelFactory
23
24
# Log record layout: level, timestamp and source location, followed by the message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
27
28
def main() -> None:
    """Export one of the example models to an ExecuTorch program file.

    Command-line arguments:
        -m/--model_name: key into ``MODEL_NAME_TO_MODEL`` (required).
        -s/--strict, --no-strict: value forwarded as ``strict`` to the export
            helpers. Defaults to True, matching the advertised help text.
        -a/--segment_alignment: segment alignment written as a hexadecimal
            string (e.g. ``0x4000`` or ``4000``). When omitted, the value from
            a default-constructed ``ExecutorchBackendConfig`` is kept.
        -o/--output_dir: output directory handed to ``save_pte_program``
            (default ``"."``).

    Models that report ``dynamic_shapes`` are lowered through
    ``export_to_edge`` (with IR validity checks disabled) and then
    ``to_executorch``. All other models go through ``export_to_exec_prog``.

    Raises:
        RuntimeError: if ``--model_name`` is not a key of ``MODEL_NAME_TO_MODEL``.
        SystemExit: raised by ``argparse`` for malformed arguments, including a
            ``--segment_alignment`` value that is not a hexadecimal integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "-s",
        "--strict",
        action=argparse.BooleanOptionalAction,
        # Without an explicit default, argparse stores None when neither
        # --strict nor --no-strict is given. That None would then be forwarded
        # as strict=None below, overriding the helpers' own default instead of
        # meaning "True" as the help text promises.
        default=True,
        help="whether to export with strict mode. Default is True",
    )

    parser.add_argument(
        "-a",
        "--segment_alignment",
        required=False,
        help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )

    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    backend_config = ExecutorchBackendConfig()
    if args.segment_alignment is not None:
        try:
            # Base 16 accepts both "4000" and the "0x4000" spelling.
            alignment = int(args.segment_alignment, 16)
        except ValueError:
            # Report bad input as a CLI usage error (exits) rather than a traceback.
            parser.error(
                "--segment_alignment must be a hexadecimal integer, "
                f"got {args.segment_alignment!r}"
            )
        backend_config.segment_alignment = alignment

    if dynamic_shapes is not None:
        edge_manager = export_to_edge(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
            strict=args.strict,
        )
        prog = edge_manager.to_executorch(config=backend_config)
    else:
        prog = export_to_exec_prog(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            backend_config=backend_config,
            strict=args.strict,
        )
    save_pte_program(prog, args.model_name, args.output_dir)
88
89
if __name__ == "__main__":
    # Run the whole export inside torch.no_grad() so autograd does not record
    # operations while the model is being exported.
    grad_disabled = torch.no_grad()
    with grad_disabled:
        main()  # pragma: no cover
93