xref: /aosp_15_r20/external/executorch/backends/cadence/aot/export_example.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# Example script for exporting simple models to flatbuffer
8
9import logging
10import tempfile
11
12import torch
13
14from executorch.backends.cadence.aot.ops_registrations import *  # noqa
15from typing import Any, Tuple
16
17from executorch.backends.cadence.aot.compiler import (
18    convert_pt2,
19    export_to_executorch_gen_etrecord,
20    fuse_pt2,
21)
22
23from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
24from executorch.backends.cadence.runtime import runtime
25from executorch.backends.cadence.runtime.executor import BundledProgramManager
26from executorch.exir import ExecutorchProgramManager
27from torch import nn
28from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
29from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
30    QuantizationConfig,
31    QuantizationSpec,
32)
33
34from .utils import save_bpte_program, save_pte_program
35
36
37FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
38logging.basicConfig(level=logging.INFO, format=FORMAT)
39
40act_qspec = QuantizationSpec(
41    dtype=torch.int8,
42    quant_min=-128,
43    quant_max=127,
44    qscheme=torch.per_tensor_affine,
45    is_dynamic=False,
46    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
47)
48
49wgt_qspec = QuantizationSpec(
50    dtype=torch.int8,
51    quant_min=-128,
52    quant_max=127,
53    qscheme=torch.per_tensor_affine,
54    is_dynamic=False,
55    observer_or_fake_quant_ctr=MinMaxObserver,
56)
57
58
def export_model(
    model: nn.Module,
    example_inputs: Tuple[Any, ...],
    file_name: str = "CadenceDemoModel",
):
    """Quantize ``model``, export it, save the programs and check the result.

    The steps run strictly in order: convert the model with a
    ``CadenceQuantizer``, record reference outputs from the converted model,
    fuse it, export it to an ExecuTorch program, save ``.pte`` and ``.bpte``
    files, and finally hand the program to ``runtime.run_and_compare``.

    Args:
        model: Module to export.
        example_inputs: Positional inputs for the model. They are used for
            conversion, for computing the reference outputs, for the export,
            and for the final comparison.
        file_name: Base name passed to the save helpers for the ``.pte`` and
            ``.bpte`` files.

    Returns:
        None. All artifacts are written to a fresh temporary directory under
        ``/tmp``, which this function does not remove.
    """
    # Scratch directory for outputs and the model binary.
    working_dir = tempfile.mkdtemp(dir="/tmp")
    logging.debug(f"Created work directory {working_dir}")

    # A single quantizer instance is shared by convert_pt2 and fuse_pt2:
    # fusing needs the same quantizer that was used for conversion.
    cadence_quantizer = CadenceQuantizer(
        QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
    )

    # Convert first, and capture the reference outputs from the converted
    # model before it is handed to fuse_pt2.
    converted = convert_pt2(model, example_inputs, cadence_quantizer)
    expected = converted(*example_inputs)

    # Quantize (fuse) the converted model.
    fused = fuse_pt2(converted, cadence_quantizer)

    # Edge program after the Cadence-specific passes.
    et_program: ExecutorchProgramManager = export_to_executorch_gen_etrecord(
        fused, example_inputs, output_dir=working_dir
    )

    logging.info("Final exported graph:\n")
    et_program.exported_program().graph_module.graph.print_tabular()

    # Bundle the program with test data for the "forward" method and
    # serialize the bundle into a buffer.
    forward_suite = BundledProgramManager.bundled_program_test_data_gen(
        method="forward", inputs=example_inputs, expected_outputs=expected
    )
    bundle_manager = BundledProgramManager([forward_suite])
    method_suites = bundle_manager.get_method_test_suites()
    buffer = bundle_manager._serialize(et_program, method_suites, forward_suite)

    # Save the program as pte (default name is CadenceDemoModel.pte)
    save_pte_program(et_program, file_name, working_dir)
    # Save the program as bpte (default name is CadenceDemoModel.bpte)
    save_bpte_program(buffer, file_name, working_dir)

    logging.debug(
        f"Executorch bundled program buffer saved to {file_name} is {len(buffer)} total bytes"
    )

    # TODO: move to test infra
    runtime.run_and_compare(
        executorch_prog=et_program,
        inputs=example_inputs,
        ref_outputs=expected,
        working_dir=working_dir,
    )
121