# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

from typing import Any, Dict, Optional, Tuple, Union

import executorch.exir as exir

import torch
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from executorch.exir.tracer import Value
from torch.export import export, export_for_training, ExportedProgram


_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
    _check_ir_validity=True,
    _skip_dim_order=True,  # TODO(T189114319): Reuse dim order op after solving the ios oss issue
)


def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    strict: bool = True,
    verbose: bool = True,
) -> ExportedProgram:
    """Export ``model`` with ``torch.export.export`` (post-autograd export).

    Eventually this will become ``.to_core_aten``.

    Args:
        model: Module to export. ``torch.fx.GraphModule`` subclasses
            ``torch.nn.Module``, so one ``isinstance`` check accepts both.
        example_inputs: Positional example inputs used for tracing.
        example_kwarg_inputs: Optional keyword example inputs used for tracing.
        dynamic_shapes: Optional dynamic-shape spec forwarded to ``export``.
        strict: Forwarded to ``export``.
        verbose: When true, log the exported graph at INFO level.

    Returns:
        The ``ExportedProgram`` returned by ``torch.export.export``.

    Raises:
        ValueError: If ``model`` is not a ``torch.nn.Module``.
    """
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            "Expected passed in model to be an instance of fx.GraphModule or "
            f"nn.Module, got {type(model)}"
        )
    core_aten_ep = export(
        model,
        example_inputs,
        kwargs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    )
    if verbose:
        # Lazy %-args: the graph is only stringified if INFO is enabled.
        logging.info("Core ATen graph:\n%s", core_aten_ep.graph)
    return core_aten_ep


def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config: Optional[exir.EdgeCompileConfig] = None,
    verbose: bool = True,
) -> EdgeProgramManager:
    """Lower an ``ExportedProgram`` to an ``EdgeProgramManager`` via ``to_edge``.

    Args:
        core_aten_exir_ep: Program produced by ``torch.export``.
        edge_constant_methods: Forwarded to ``to_edge`` as ``constant_methods``.
        edge_compile_config: Forwarded to ``to_edge`` as ``compile_config``.
            When ``None``, a config with IR validity checking disabled and
            dim order skipped is used instead.
        verbose: When true, log the resulting exported program at INFO level.

    Returns:
        The ``EdgeProgramManager`` returned by ``to_edge``.
    """
    if edge_compile_config is None:
        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False,  # quant ops currently break ir verification
            _skip_dim_order=True,  # TODO(T182928844): dim order ops can not delegate to backend
        )
    edge_manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=edge_compile_config,
    )
    if verbose:
        logging.info("Exported graph:\n%s", edge_manager.exported_program())
    return edge_manager


def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    strict: bool = True,
    verbose: bool = True,
) -> EdgeProgramManager:
    """Export ``model`` and lower it to an ``EdgeProgramManager``.

    This is ``_to_core_aten`` followed by ``_core_aten_to_edge``. All
    arguments are forwarded to the corresponding step. ``edge_compile_config``
    defaults to the module-level ``_EDGE_COMPILE_CONFIG``, which enables IR
    validity checking.
    """
    core_aten_ep = _to_core_aten(
        model,
        example_inputs,
        example_kwarg_inputs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
        verbose=verbose,
    )
    return _core_aten_to_edge(
        core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
    )


def export_to_exec_prog(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    backend_config=None,
    strict: bool = True,
) -> ExecutorchProgramManager:
    """Export ``model`` all the way to an ``ExecutorchProgramManager``.

    Steps:
      1. Put ``model`` in eval mode. ``nn.Module.eval`` mutates the module
         in place and returns it.
      2. Run a pre-autograd ``export_for_training`` and unlift it with
         ``.module()``.
      3. Run ``torch.export.export`` on that module (``_to_core_aten``).
      4. Lower the result with ``to_edge`` (``_core_aten_to_edge``).
      5. Call ``to_executorch(backend_config)``.

    The same ``example_kwarg_inputs``, ``dynamic_shapes`` and ``strict`` are
    used for both exports. If the first export traced without the kwargs, or
    specialized dynamic dims to constants, the second export would fail or
    conflict with ``dynamic_shapes``.
    """
    m = model.eval()
    # pre-autograd export. eventually this will become torch.export
    m = export_for_training(
        m,
        example_inputs,
        kwargs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    ).module()

    core_aten_ep = _to_core_aten(
        m,
        example_inputs,
        example_kwarg_inputs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    )

    edge_m = _core_aten_to_edge(
        core_aten_ep, edge_constant_methods, edge_compile_config
    )

    exec_prog = edge_m.to_executorch(backend_config)
    return exec_prog


def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> str:
    """Write ``prog`` to a ``.pte`` file and return the path used.

    If ``model_name`` already ends with ``.pte``, it is used verbatim as the
    path and ``output_dir`` is ignored. Otherwise the path is
    ``os.path.join(output_dir, f"{model_name}.pte")``.

    Saving is best-effort. Any exception raised while opening or writing the
    file is logged with its traceback and not re-raised. The computed path is
    returned whether or not the write succeeded.
    """
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
        logging.info("Saved exported program to %s", filename)
    except Exception as e:
        # Keep the best-effort contract, but preserve the traceback.
        logging.exception("Error while saving to %s: %s", filename, e)

    return filename