xref: /aosp_15_r20/external/executorch/extension/export_util/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8import os
9
10from typing import Any, Dict, Optional, Tuple, Union
11
12import executorch.exir as exir
13
14import torch
15from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
16from executorch.exir.tracer import Value
17from torch.export import export, export_for_training, ExportedProgram
18
19
# Default Edge-dialect compile settings for export_to_edge() and
# export_to_exec_prog(). IR validity checking is enabled here, and dim-order
# ops are skipped.
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
    # TODO(T189114319): Reuse dim order op after solving the ios oss issue
    _skip_dim_order=True,
    _check_ir_validity=True,
)
24
25
def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None,
    strict: bool = True,
    verbose: bool = True,
) -> ExportedProgram:
    """Export ``model`` with ``torch.export.export`` and return the program.

    This is the post-autograd export step; eventually it is expected to
    become ``.to_core_aten``.

    Args:
        model: Module to export. A ``torch.fx.GraphModule`` is accepted
            because it is itself a ``torch.nn.Module``.
        example_inputs: Positional example inputs used for tracing.
        example_kwarg_inputs: Optional keyword example inputs used for tracing.
        dynamic_shapes: Optional dynamic-shape specification forwarded to
            ``torch.export.export``.
        strict: Forwarded to ``torch.export.export``.
        verbose: If true, log the exported graph at INFO level.

    Returns:
        The ``ExportedProgram`` produced by ``torch.export.export``.

    Raises:
        ValueError: If ``model`` is not a ``torch.nn.Module``.
    """
    # fx.GraphModule subclasses nn.Module, so one isinstance check covers
    # both accepted types.
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            "Expected passed in model to be an instance of fx.GraphModule or "
            f"nn.Module, got {type(model)}"
        )
    core_aten_ep = export(
        model,
        example_inputs,
        kwargs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    )
    if verbose:
        # Lazy %-formatting: the (possibly large) graph is only rendered to a
        # string if INFO records are actually emitted.
        logging.info("Core ATen graph:\n%s", core_aten_ep.graph)
    return core_aten_ep
52
53
def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config: Optional[exir.EdgeCompileConfig] = None,
    verbose: bool = True,
) -> EdgeProgramManager:
    """Lower an exported program to the Edge dialect via ``to_edge``.

    Args:
        core_aten_exir_ep: Program produced by ``torch.export.export``.
        edge_constant_methods: Optional constant methods forwarded to
            ``to_edge`` as ``constant_methods``.
        edge_compile_config: Edge compile config. When ``None``, a fallback
            config is used that disables IR validity checking and skips
            dim-order ops (see the inline comments for why).
        verbose: If true, log the resulting exported program at INFO level.

    Returns:
        The ``EdgeProgramManager`` returned by ``to_edge``.
    """
    if edge_compile_config is None:
        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False,  # quant ops currently break ir verification
            _skip_dim_order=True,  # TODO(T182928844): dim order ops can not delegate to backend
        )
    edge_manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=edge_compile_config,
    )
    if verbose:
        # Lazy %-formatting avoids stringifying the program when INFO is off.
        logging.info("Exported graph:\n%s", edge_manager.exported_program())
    return edge_manager
73
74
def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    strict: bool = True,
    verbose: bool = True,
) -> EdgeProgramManager:
    """Export ``model`` with ``torch.export`` and lower it to the Edge dialect.

    Unlike ``export_to_exec_prog``, this neither calls ``model.eval()`` nor
    runs a pre-autograd ``export_for_training`` pass first.

    Args:
        model: Module (or GraphModule) to export.
        example_inputs: Positional example inputs used for tracing.
        example_kwarg_inputs: Optional keyword example inputs used for tracing.
        dynamic_shapes: Optional dynamic-shape specification for export.
        edge_constant_methods: Optional constant methods for ``to_edge``.
        edge_compile_config: Edge compile config; defaults to the module-level
            ``_EDGE_COMPILE_CONFIG``. Passing ``None`` makes
            ``_core_aten_to_edge`` use its own fallback config.
        strict: Forwarded to ``torch.export.export``.
        verbose: If true, log the intermediate graphs at INFO level.

    Returns:
        The ``EdgeProgramManager`` for the lowered program.
    """
    core_aten_ep = _to_core_aten(
        model,
        example_inputs,
        example_kwarg_inputs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
        verbose=verbose,
    )
    return _core_aten_to_edge(
        core_aten_ep,
        edge_constant_methods=edge_constant_methods,
        edge_compile_config=edge_compile_config,
        verbose=verbose,
    )
97
98
def export_to_exec_prog(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    *,
    example_kwarg_inputs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    backend_config=None,
    strict: bool = True,
) -> ExecutorchProgramManager:
    """Export ``model`` all the way to an ExecuTorch program.

    The pipeline is:

    1. ``model.eval()``
    2. pre-autograd ``export_for_training``
    3. ``torch.export.export``
    4. Edge dialect lowering
    5. ``to_executorch``

    Note: ``model.eval()`` is called on the caller's module, so that module
    is left in eval mode.

    Args:
        model: Module (or GraphModule) to export.
        example_inputs: Positional example inputs used for tracing.
        example_kwarg_inputs: Optional keyword example inputs used for tracing.
        dynamic_shapes: Optional dynamic-shape specification, applied to both
            export passes.
        edge_constant_methods: Optional constant methods for ``to_edge``.
        edge_compile_config: Edge compile config; defaults to the module-level
            ``_EDGE_COMPILE_CONFIG``.
        backend_config: Forwarded to ``EdgeProgramManager.to_executorch``.
        strict: Export strictness, applied to both export passes.

    Returns:
        The ``ExecutorchProgramManager`` produced by ``to_executorch``.
    """
    m = model.eval()
    # Pre-autograd export (eventually this will become torch.export). It must
    # see the same kwargs, dynamic shapes and strictness as the later export.
    # Without dynamic_shapes this trace would treat input dims as static at
    # the example sizes. Without kwargs it would trace a different call.
    m = export_for_training(
        m,
        example_inputs,
        kwargs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    ).module()

    core_aten_ep = _to_core_aten(
        m,
        example_inputs,
        example_kwarg_inputs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=strict,
    )

    edge_m = _core_aten_to_edge(
        core_aten_ep,
        edge_constant_methods=edge_constant_methods,
        edge_compile_config=edge_compile_config,
    )

    return edge_m.to_executorch(backend_config)
128
129
def save_pte_program(
    prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> str:
    """Serialize ``prog`` to a ``.pte`` file and return the path used.

    If ``model_name`` already ends with ``.pte``, it is used as the path
    as-is and ``output_dir`` is ignored. Otherwise the path is
    ``os.path.join(output_dir, model_name + ".pte")``.

    Saving is best-effort. A failure is logged together with its traceback
    rather than raised, and the computed path is returned either way, so
    callers should not assume the file exists or is complete.
    """
    if model_name.endswith(".pte"):
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
    except Exception as e:
        # Deliberately non-raising (callers always get `filename` back), but
        # keep the traceback so the failure can be diagnosed.
        logging.exception(f"Error while saving to {filename}: {e}")
    else:
        # Report success only after the file has been closed (and therefore
        # flushed). An error during close is then not logged as a save.
        logging.info("Saved exported program to %s", filename)

    return filename
146