# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import argparse
import logging

import torch

from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.extension.export_util.utils import (
    export_to_edge,
    export_to_exec_prog,
    save_pte_program,
)

from ...models import MODEL_NAME_TO_MODEL
from ...models.model_factory import EagerModelFactory


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def main() -> None:
    """Export a registered example model to an ExecuTorch program file.

    Parses command-line arguments, builds the model named by ``--model_name``
    through ``EagerModelFactory.create_model``, lowers it to an ExecuTorch
    program, and writes it to ``--output_dir`` via ``save_pte_program``.

    Raises:
        RuntimeError: if ``--model_name`` is not a key of ``MODEL_NAME_TO_MODEL``.
        ValueError: if ``--segment_alignment`` is not a valid base-16 integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "-s",
        "--strict",
        action=argparse.BooleanOptionalAction,
        # BooleanOptionalAction leaves the value as None when neither
        # --strict nor --no-strict is given; set the default explicitly so
        # the documented "Default is True" actually holds.
        default=True,
        help="whether to export with strict mode. Default is True",
    )

    parser.add_argument(
        "-a",
        "--segment_alignment",
        required=False,
        help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )

    # Build the backend config (and validate the hex alignment) before
    # constructing the model, so a malformed --segment_alignment fails fast
    # instead of after model creation, which may be expensive.
    backend_config = ExecutorchBackendConfig()
    if args.segment_alignment is not None:
        backend_config.segment_alignment = int(args.segment_alignment, 16)

    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    if dynamic_shapes is not None:
        # Models with dynamic shapes go through the two-step edge path with
        # IR validity checking disabled.
        edge_manager = export_to_edge(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
            strict=args.strict,
        )
        prog = edge_manager.to_executorch(config=backend_config)
    else:
        prog = export_to_exec_prog(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            backend_config=backend_config,
            strict=args.strict,
        )
    save_pte_program(prog, args.model_name, args.output_dir)


if __name__ == "__main__":
    with torch.no_grad():
        main()  # pragma: no cover