# Building and Running ExecuTorch with XNNPACK Backend

The following tutorial will familiarize you with leveraging the ExecuTorch XNNPACK Delegate for accelerating your ML models using CPU hardware. It will go over exporting and serializing a model to a binary file, targeting the XNNPACK Delegate Backend, and running the model on a supported target platform. To get started quickly, use the script in the ExecuTorch repository with instructions on exporting and generating a binary file for a few sample models demonstrating the flow.

<!----This will show a grid card on the page----->
::::{grid} 2
:::{grid-item-card}  What you will learn in this tutorial:
:class-card: card-prerequisites
In this tutorial, you will learn how to export an XNNPACK-lowered model and run it on a target platform.
:::
:::{grid-item-card}  Before you begin, it is recommended you go through the following:
:class-card: card-prerequisites
* [Setting up ExecuTorch](./getting-started-setup.md)
* [Model Lowering Tutorial](./tutorials/export-to-executorch-tutorial)
* [ExecuTorch XNNPACK Delegate](./native-delegates-executorch-xnnpack-delegate.md)
:::
::::


## Lowering a model to XNNPACK
```python
import torch
import torchvision.models as models

from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.backend_api import to_backend


mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
)
```

We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. Lowering happens while converting to the Edge dialect: we call `to_edge_transform_and_lower`, passing in the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for the XNNPACK backend delegate to consume. Afterwards, the identified subgraphs are serialized with the XNNPACK Delegate flatbuffer schema, and each subgraph is replaced with a call to the XNNPACK Delegate.
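Before inspecting the full graph module below, you can optionally print a quick summary of how much of the graph was delegated. The snippet below is a minimal sketch; it assumes the `executorch.devtools.backend_debug` utilities and the third-party `tabulate` package are available in your environment.

```python
# Optional: summarize how many operators were delegated to XNNPACK.
# Assumes executorch.devtools and the `tabulate` package are installed.
from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

graph_module = edge.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
```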

```python
>>> print(edge.exported_program().graph_module)
GraphModule(
  (lowered_module_0): LoweredBackendModule()
  (lowered_module_1): LoweredBackendModule()
)



def forward(self, b_features_0_1_num_batches_tracked, ..., x):
    lowered_module_0 = self.lowered_module_0
    lowered_module_1 = self.lowered_module_1
    executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, x);  lowered_module_1 = x = None
    getitem_53 = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem_53, [1, 1280]);  getitem_53 = None
    aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default);  aten_view_copy_default = None
    executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_clone_default);  lowered_module_0 = aten_clone_default = None
    getitem_52 = executorch_call_delegate[0];  executorch_call_delegate = None
    return (getitem_52,)
```

We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraph being delegated to XNNPACK is the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were delegated to XNNPACK. We can also see the operators which could not be lowered to the XNNPACK delegate, such as `clone` and `view_copy`.

```python
exec_prog = edge.to_executorch()

with open("xnnpack_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)
```
After lowering to the XNNPACK backend, we can then prepare the program for ExecuTorch and save the model as a `.pte` file. `.pte` is a binary format that stores the serialized ExecuTorch graph.
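Before deploying the `.pte` file to a device, you can optionally sanity-check it from Python using the ExecuTorch pybindings. The snippet below is a minimal sketch; it assumes the Python runtime bindings (`executorch.extension.pybindings.portable_lib`) were built with the XNNPACK backend enabled.

```python
# Optional sanity check: load the saved .pte and run it from Python.
# Assumes the ExecuTorch pybindings were built with XNNPACK support.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

executorch_module = _load_for_executorch("xnnpack_mobilenetv2.pte")
outputs = executorch_module.forward([sample_inputs[0]])
print(outputs[0].shape)  # MobileNetV2 logits: torch.Size([1, 1000])
```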


## Lowering a Quantized Model to XNNPACK
The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to the [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize()` Python helper function conveniently added to the `executorch/executorch/examples` folder.

```python
from torch.export import export_for_training
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


def quantize(model, example_inputs):
    """This is the official recommended flow for quantization in pytorch 2.0 export"""
    print(f"Original model: {model}")
    quantizer = XNNPACKQuantizer()
    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)
    m = prepare_pt2e(model, quantizer)
    # calibration
    m(*example_inputs)
    m = convert_pt2e(m)
    print(f"Quantized model: {m}")
    # make sure we can export to flat buffer
    return m

quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
```

Quantization requires a two-stage export. First we use the `export_for_training` API to capture the model before passing it to the `quantize` utility function. After performing the quantization step, we can leverage the XNNPACK delegate to lower the quantized exported model graph. From here, the procedure is the same as lowering a non-quantized model to XNNPACK.

```python
# Continued from earlier...
edge = to_edge_transform_and_lower(
    export(quantized_mobilenetv2, sample_inputs),
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
    partitioner=[XnnpackPartitioner()]
)

exec_prog = edge.to_executorch()

with open("qs8_xnnpack_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)
```

## Lowering with `aot_compiler.py` script
We have also provided a script to quickly lower and export a few example models. You can run the script to generate lowered fp32 and quantized models. This script is used simply for convenience and performs all the same steps as those listed in the previous two sections.

```bash
python -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
```

Note that in the example above:
* the `--model_name` flag specifies the model to use
* the `--quantize` flag controls whether the model should be quantized or not
* the `--delegate` flag controls whether we attempt to lower parts of the graph to the XNNPACK delegate

The generated model file will be named `[model_name]_xnnpack_[qs8/fp32].pte` depending on the arguments supplied.

## Running the XNNPACK Model with CMake
After exporting the XNNPACK Delegated model, we can now try running it with example inputs using CMake. We can build and use the `xnn_executor_runner`, which is a sample wrapper for the ExecuTorch Runtime and XNNPACK Backend. We first begin by configuring the CMake build as follows:
```bash
# cd to the root of executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with:

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable built at `./cmake-out/backends/xnnpack/xnn_executor_runner`. You can run it with the model you generated as follows:
```bash
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_fp32.pte
# or to run the quantized variant
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_q8.pte
```

## Building and Linking with the XNNPACK Backend
You can build the XNNPACK backend [CMake target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/CMakeLists.txt#L83) and link it with your application binary, such as an Android or iOS application. For more information, take a look at the [Android demo app](demo-apps-android.md) next.
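As a rough sketch of what that integration could look like, the CMake fragment below links the XNNPACK backend into a hypothetical `my_app` binary. The target names follow the ExecuTorch build shown above, but the exact set of libraries your application needs (and whether you must force whole-archive linking so the backend's static registration is kept) depends on your build configuration.

```cmake
# Hypothetical application CMakeLists.txt fragment (sketch only).
# Assumes ExecuTorch was installed into `cmake-out` as in the steps above.
cmake_minimum_required(VERSION 3.19)
project(my_app)

# Point CMake at the ExecuTorch install tree so find_package can locate it.
list(APPEND CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/executorch/cmake-out")
find_package(executorch REQUIRED CONFIG)

add_executable(my_app main.cpp)

# Link the core runtime plus the XNNPACK backend. Depending on your
# toolchain, you may need whole-archive linking of the backend so its
# static registration code is not stripped from the final binary.
target_link_libraries(
    my_app
    PRIVATE
    executorch
    xnnpack_backend
    extension_module_static
    extension_tensor
)
```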